===================================================================
@@ -736,3 +736,32 @@ func TestPtrToMapOfMap(t *testing.T) {
t.Fatalf("expected %v got %v", data, newData)
}
}
+
+// There was an error check comparing the length of the input with the
+// length of the slice being decoded. It was wrong because the next
+// thing in the input might be a type definition, which would lead to
+// an incorrect length check. This test reproduces the corner case.
+
+type Z struct {
+}
+
+func Test29ElementSlice(t *testing.T) {
+ Register(Z{})
+ src := make([]interface{}, 100) // Size needs to be bigger than size of type definition.
+ for i := range src {
+ src[i] = Z{}
+ }
+ buf := new(bytes.Buffer)
+ err := NewEncoder(buf).Encode(src)
+ if err != nil {
+ t.Fatalf("encode: %v", err)
+ return
+ }
+
+ var dst []interface{}
+ err = NewDecoder(buf).Decode(&dst)
+ if err != nil {
+ t.Errorf("decode: %v", err)
+ return
+ }
+}
===================================================================
@@ -562,6 +562,9 @@ func (dec *Decoder) ignoreSingle(engine
func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl error) {
instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl}
for i := 0; i < length; i++ {
+ if state.b.Len() == 0 {
+ errorf("decoding array or slice: length exceeds input size (%d elements)", length)
+ }
up := unsafe.Pointer(p)
if elemIndir > 1 {
up = decIndirect(up, elemIndir)
@@ -652,9 +655,6 @@ func (dec *Decoder) ignoreMap(state *dec
// Slices are encoded as an unsigned length followed by the elements.
func (dec *Decoder) decodeSlice(atyp reflect.Type, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl error) {
nr := state.decodeUint()
- if nr > uint64(state.b.Len()) {
- errorf("length of slice exceeds input size (%d elements)", nr)
- }
n := int(nr)
if indir > 0 {
up := unsafe.Pointer(p)
===================================================================
@@ -80,6 +80,7 @@ package testing
import (
_ "debug/elf"
+ "bytes"
"flag"
"fmt"
"os"
@@ -87,6 +88,7 @@ import (
"runtime/pprof"
"strconv"
"strings"
+ "sync"
"time"
)
@@ -116,8 +118,10 @@ var (
// common holds the elements common between T and B and
// captures common methods such as Errorf.
type common struct {
- output []byte // Output generated by test or benchmark.
- failed bool // Test or benchmark has failed.
+ mu sync.RWMutex // guards output and failed
+ output []byte // Output generated by test or benchmark.
+ failed bool // Test or benchmark has failed.
+
start time.Time // Time test or benchmark started
duration time.Duration
self interface{} // To be sent on signal channel when done.
@@ -129,37 +133,42 @@ func Short() bool {
return *short
}
-// decorate inserts the final newline if needed and indentation tabs for formatting.
-// If addFileLine is true, it also prefixes the string with the file and line of the call site.
-func decorate(s string, addFileLine bool) string {
- if addFileLine {
- _, file, line, ok := runtime.Caller(3) // decorate + log + public function.
- if ok {
- // Truncate file name at last file name separator.
- if index := strings.LastIndex(file, "/"); index >= 0 {
- file = file[index+1:]
- } else if index = strings.LastIndex(file, "\\"); index >= 0 {
- file = file[index+1:]
- }
- } else {
- file = "???"
- line = 1
- }
- s = fmt.Sprintf("%s:%d: %s", file, line, s)
- }
- s = "\t" + s // Every line is indented at least one tab.
- n := len(s)
- if n > 0 && s[n-1] != '\n' {
- s += "\n"
- n++
+// decorate prefixes the string with the file and line of the call site
+// and inserts the final newline if needed and indentation tabs for formatting.
+func decorate(s string) string {
+ _, file, line, ok := runtime.Caller(3) // decorate + log + public function.
+ if ok {
+ // Truncate file name at last file name separator.
+ if index := strings.LastIndex(file, "/"); index >= 0 {
+ file = file[index+1:]
+ } else if index = strings.LastIndex(file, "\\"); index >= 0 {
+ file = file[index+1:]
+ }
+ } else {
+ file = "???"
+ line = 1
}
- for i := 0; i < n-1; i++ { // -1 to avoid final newline
- if s[i] == '\n' {
+ buf := new(bytes.Buffer)
+ fmt.Fprintf(buf, "%s:%d: ", file, line)
+
+ lines := strings.Split(s, "\n")
+ for i, line := range lines {
+ if i > 0 {
+ buf.WriteByte('\n')
+ }
+ // Every line is indented at least one tab.
+ buf.WriteByte('\t')
+ if i > 0 {
// Second and subsequent lines are indented an extra tab.
- return s[0:i+1] + "\t" + decorate(s[i+1:n], false)
+ buf.WriteByte('\t')
}
+ buf.WriteString(line)
+ }
+ if l := len(s); l > 0 && s[len(s)-1] != '\n' {
+ // Add final new line if needed.
+ buf.WriteByte('\n')
}
- return s
+ return buf.String()
}
// T is a type passed to Test functions to manage test state and support formatted test logs.
@@ -171,10 +180,18 @@ type T struct {
}
// Fail marks the function as having failed but continues execution.
-func (c *common) Fail() { c.failed = true }
+func (c *common) Fail() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.failed = true
+}
// Failed returns whether the function has failed.
-func (c *common) Failed() bool { return c.failed }
+func (c *common) Failed() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ return c.failed
+}
// FailNow marks the function as having failed and stops its execution.
// Execution will continue at the next test or benchmark.
@@ -205,7 +222,9 @@ func (c *common) FailNow() {
// log generates the output. It's always at the same stack depth.
func (c *common) log(s string) {
- c.output = append(c.output, decorate(s, true)...)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.output = append(c.output, decorate(s)...)
}
// Log formats its arguments using default formatting, analogous to Println(),
@@ -298,7 +317,7 @@ func Main(matchString func(pat, str stri
func (t *T) report() {
tstr := fmt.Sprintf("(%.2f seconds)", t.duration.Seconds())
format := "--- %s: %s %s\n%s"
- if t.failed {
+ if t.Failed() {
fmt.Printf(format, "FAIL", t.name, tstr, t.output)
} else if *chatty {
fmt.Printf(format, "PASS", t.name, tstr, t.output)
@@ -357,7 +376,7 @@ func RunTests(matchString func(pat, str
continue
}
t.report()
- ok = ok && !out.failed
+ ok = ok && !out.Failed()
}
running := 0
@@ -370,7 +389,7 @@ func RunTests(matchString func(pat, str
}
t := (<-collector).(*T)
t.report()
- ok = ok && !t.failed
+ ok = ok && !t.Failed()
running--
}
}
===================================================================
@@ -8,6 +8,7 @@ package http_test
import (
"crypto/tls"
+ "crypto/x509"
"errors"
"fmt"
"io"
@@ -465,3 +466,49 @@ func TestClientErrorWithRequestURI(t *te
t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
}
}
+
+func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
+ certs := x509.NewCertPool()
+ for _, c := range ts.TLS.Certificates {
+ roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+ if err != nil {
+ t.Fatalf("error parsing server's root cert: %v", err)
+ }
+ for _, root := range roots {
+ certs.AddCert(root)
+ }
+ }
+ return &Transport{
+ TLSClientConfig: &tls.Config{RootCAs: certs},
+ }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS.ServerName != "127.0.0.1" {
+ t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+ }
+ }))
+ defer ts.Close()
+
+ c := &Client{Transport: newTLSTransport(t, ts)}
+ if _, err := c.Get(ts.URL); err != nil {
+ t.Fatalf("expected successful TLS connection, got error: %v", err)
+ }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+
+ trans := newTLSTransport(t, ts)
+ trans.TLSClientConfig.ServerName = "badserver"
+ c := &Client{Transport: trans}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+ t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+ }
+}
===================================================================
@@ -168,6 +168,9 @@ var vtests = []struct {
{"http://someHost.com/someDir/apage", "someHost.com/someDir"},
{"http://otherHost.com/someDir/apage", "someDir"},
{"http://otherHost.com/aDir/apage", "Default"},
+ // redirections for trees
+ {"http://localhost/someDir", "/someDir/"},
+ {"http://someHost.com/someDir", "/someDir/"},
}
func TestHostHandlers(t *testing.T) {
@@ -199,9 +202,19 @@ func TestHostHandlers(t *testing.T) {
t.Errorf("reading response: %v", err)
continue
}
- s := r.Header.Get("Result")
- if s != vt.expected {
- t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ switch r.StatusCode {
+ case StatusOK:
+ s := r.Header.Get("Result")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ case StatusMovedPermanently:
+ s := r.Header.Get("Location")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ default:
+ t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
}
}
}
@@ -1108,6 +1121,38 @@ func TestServerBufferedChunking(t *testi
}
}
+// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1
+// request (both keep-alive), when a Handler never writes any
+// response, the net/http package adds a "Content-Length: 0" response
+// header.
+func TestContentLengthZero(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
+ defer ts.Close()
+
+ for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+ req, _ := NewRequest("GET", "/", nil)
+ res, err := ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("error reading response: %v", err)
+ }
+ if te := res.TransferEncoding; len(te) > 0 {
+ t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
+ }
+ if cl := res.ContentLength; cl != 0 {
+ t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
+ }
+ conn.Close()
+ }
+}
+
// goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, d time.Duration, f func()) {
ch := make(chan bool, 2)
===================================================================
@@ -14,15 +14,34 @@ var ParseRangeTests = []struct {
r []httpRange
}{
{"", 0, nil},
+ {"", 1000, nil},
{"foo", 0, nil},
{"bytes=", 0, nil},
+ {"bytes=7", 10, nil},
+ {"bytes= 7 ", 10, nil},
+ {"bytes=1-", 0, nil},
{"bytes=5-4", 10, nil},
{"bytes=0-2,5-4", 10, nil},
+ {"bytes=2-5,4-3", 10, nil},
+ {"bytes=--5,4--3", 10, nil},
+ {"bytes=A-", 10, nil},
+ {"bytes=A- ", 10, nil},
+ {"bytes=A-Z", 10, nil},
+ {"bytes= -Z", 10, nil},
+ {"bytes=5-Z", 10, nil},
+ {"bytes=Ran-dom, garbage", 10, nil},
+ {"bytes=0x01-0x02", 10, nil},
+ {"bytes= ", 10, nil},
+ {"bytes= , , , ", 10, nil},
+
{"bytes=0-9", 10, []httpRange{{0, 10}}},
{"bytes=0-", 10, []httpRange{{0, 10}}},
{"bytes=5-", 10, []httpRange{{5, 5}}},
{"bytes=0-20", 10, []httpRange{{0, 10}}},
{"bytes=15-,0-5", 10, nil},
+ {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
+ {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
+ {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
{"bytes=-5", 10, []httpRange{{5, 5}}},
{"bytes=-15", 10, []httpRange{{0, 10}}},
{"bytes=0-499", 10000, []httpRange{{0, 500}}},
@@ -32,6 +51,9 @@ var ParseRangeTests = []struct {
{"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
{"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
{"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+
+ // Match Apache laxity:
+ {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}},
}
func TestParseRange(t *testing.T) {
===================================================================
@@ -508,8 +508,16 @@ func (w *response) Write(data []byte) (n
}
func (w *response) finishRequest() {
- // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
- // back, we can make this a keep-alive response ...
+ // If the handler never wrote any bytes and never sent a Content-Length
+ // response header, set the length explicitly to zero. This helps
+ // HTTP/1.0 clients keep their "keep-alive" connections alive, and for
+ // HTTP/1.1 clients is just as good as the alternative: sending a
+ // chunked response and immediately sending the zero-length EOF chunk.
+ if w.written == 0 && w.header.Get("Content-Length") == "" {
+ w.header.Set("Content-Length", "0")
+ }
+ // If this was an HTTP/1.0 request with keep-alive and we sent a
+ // Content-Length back, we can make this a keep-alive response ...
if w.req.wantsHttp10KeepAlive() {
sentLength := w.header.Get("Content-Length") != ""
if sentLength && w.header.Get("Connection") == "keep-alive" {
@@ -836,13 +844,15 @@ func RedirectHandler(url string, code in
// redirecting any request containing . or .. elements to an
// equivalent .- and ..-free URL.
type ServeMux struct {
- mu sync.RWMutex
- m map[string]muxEntry
+ mu sync.RWMutex
+ m map[string]muxEntry
+ hosts bool // whether any patterns contain hostnames
}
type muxEntry struct {
explicit bool
h Handler
+ pattern string
}
// NewServeMux allocates and returns a new ServeMux.
@@ -883,8 +893,7 @@ func cleanPath(p string) string {
// Find a handler on a handler map given a path string
// Most-specific (longest) pattern wins
-func (mux *ServeMux) match(path string) Handler {
- var h Handler
+func (mux *ServeMux) match(path string) (h Handler, pattern string) {
var n = 0
for k, v := range mux.m {
if !pathMatch(k, path) {
@@ -893,37 +902,59 @@ func (mux *ServeMux) match(path string)
if h == nil || len(k) > n {
n = len(k)
h = v.h
+ pattern = v.pattern
+ }
+ }
+ return
+}
+
+// Handler returns the handler to use for the given request,
+// consulting r.Method, r.Host, and r.URL.Path. It always returns
+// a non-nil handler. If the path is not in its canonical form, the
+// handler will be an internally-generated handler that redirects
+// to the canonical path.
+//
+// Handler also returns the registered pattern that matches the
+// request or, in the case of internally-generated redirects,
+// the pattern that will match after following the redirect.
+//
+// If there is no registered handler that applies to the request,
+// Handler returns a ``page not found'' handler and an empty pattern.
+func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
+ if r.Method != "CONNECT" {
+ if p := cleanPath(r.URL.Path); p != r.URL.Path {
+ _, pattern = mux.handler(r.Host, p)
+ return RedirectHandler(p, StatusMovedPermanently), pattern
}
}
- return h
+
+ return mux.handler(r.Host, r.URL.Path)
}
-// handler returns the handler to use for the request r.
-func (mux *ServeMux) handler(r *Request) Handler {
+// handler is the main implementation of Handler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
mux.mu.RLock()
defer mux.mu.RUnlock()
// Host-specific pattern takes precedence over generic ones
- h := mux.match(r.Host + r.URL.Path)
+ if mux.hosts {
+ h, pattern = mux.match(host + path)
+ }
if h == nil {
- h = mux.match(r.URL.Path)
+ h, pattern = mux.match(path)
}
if h == nil {
- h = NotFoundHandler()
+ h, pattern = NotFoundHandler(), ""
}
- return h
+ return
}
// ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL.
func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
- // Clean path to canonical form and redirect.
- if p := cleanPath(r.URL.Path); p != r.URL.Path {
- w.Header().Set("Location", p)
- w.WriteHeader(StatusMovedPermanently)
- return
- }
- mux.handler(r).ServeHTTP(w, r)
+ h, _ := mux.Handler(r)
+ h.ServeHTTP(w, r)
}
// Handle registers the handler for the given pattern.
@@ -942,14 +973,26 @@ func (mux *ServeMux) Handle(pattern stri
panic("http: multiple registrations for " + pattern)
}
- mux.m[pattern] = muxEntry{explicit: true, h: handler}
+ mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern}
+
+ if pattern[0] != '/' {
+ mux.hosts = true
+ }
// Helpful behavior:
// If pattern is /tree/, insert an implicit permanent redirect for /tree.
// It can be overridden by an explicit registration.
n := len(pattern)
if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit {
- mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)}
+ // If pattern contains a host name, strip it and use remaining
+ // path for redirect.
+ path := pattern
+ if pattern[0] != '/' {
+ // In pattern, at least the last character is a '/', so
+ // strings.Index can't be -1.
+ path = pattern[strings.Index(pattern, "/"):]
+ }
+ mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern}
}
}
===================================================================
@@ -10,6 +10,8 @@ import (
"fmt"
"io"
"io/ioutil"
+ "mime"
+ "mime/multipart"
"net"
. "net/http"
"net/http/httptest"
@@ -25,21 +27,29 @@ import (
)
const (
- testFile = "testdata/file"
- testFileLength = 11
+ testFile = "testdata/file"
+ testFileLen = 11
)
+type wantRange struct {
+ start, end int64 // range [start,end)
+}
+
var ServeFileRangeTests = []struct {
- start, end int
- r string
- code int
+ r string
+ code int
+ ranges []wantRange
}{
- {0, testFileLength, "", StatusOK},
- {0, 5, "0-4", StatusPartialContent},
- {2, testFileLength, "2-", StatusPartialContent},
- {testFileLength - 5, testFileLength, "-5", StatusPartialContent},
- {3, 8, "3-7", StatusPartialContent},
- {0, 0, "20-", StatusRequestedRangeNotSatisfiable},
+ {r: "", code: StatusOK},
+ {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}},
+ {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}},
+ {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}},
+ {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}},
+ {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}},
+ {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}},
+ {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}},
+ {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request
}
func TestServeFile(t *testing.T) {
@@ -65,33 +75,81 @@ func TestServeFile(t *testing.T) {
// straight GET
_, body := getBody(t, "straight get", req)
- if !equal(body, file) {
+ if !bytes.Equal(body, file) {
t.Fatalf("body mismatch: got %q, want %q", body, file)
}
// Range tests
- for i, rt := range ServeFileRangeTests {
- req.Header.Set("Range", "bytes="+rt.r)
- if rt.r == "" {
- req.Header["Range"] = nil
- }
- r, body := getBody(t, fmt.Sprintf("test %d", i), req)
- if r.StatusCode != rt.code {
- t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code)
+ for _, rt := range ServeFileRangeTests {
+ if rt.r != "" {
+ req.Header.Set("Range", rt.r)
+ }
+ resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req)
+ if resp.StatusCode != rt.code {
+ t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
}
if rt.code == StatusRequestedRangeNotSatisfiable {
continue
}
- h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength)
- if rt.r == "" {
- h = ""
- }
- cr := r.Header.Get("Content-Range")
- if cr != h {
- t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h)
- }
- if !equal(body, file[rt.start:rt.end]) {
- t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end])
+ wantContentRange := ""
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ }
+ cr := resp.Header.Get("Content-Range")
+ if cr != wantContentRange {
+ t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange)
+ }
+ ct := resp.Header.Get("Content-Type")
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ if strings.HasPrefix(ct, "multipart/byteranges") {
+ t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r)
+ }
+ }
+ if len(rt.ranges) > 1 {
+ typ, params, err := mime.ParseMediaType(ct)
+ if err != nil {
+ t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err)
+ continue
+ }
+ if typ != "multipart/byteranges" {
+ t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r)
+ continue
+ }
+ if params["boundary"] == "" {
+ t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct)
+ }
+ if g, w := resp.ContentLength, int64(len(body)); g != w {
+ t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w)
+ }
+ mr := multipart.NewReader(bytes.NewReader(body), params["boundary"])
+ for ri, rng := range rt.ranges {
+ part, err := mr.NextPart()
+ if err != nil {
+ t.Fatalf("range=%q, reading part index %d: %v", rt.r, ri, err)
+ }
+ body, err := ioutil.ReadAll(part)
+ if err != nil {
+ t.Fatalf("range=%q, reading part index %d body: %v", rt.r, ri, err)
+ }
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w {
+ t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w)
+ }
+ }
+ _, err = mr.NextPart()
+ if err != io.EOF {
+ t.Errorf("range=%q; expected final error io.EOF; got %v", err)
+ }
}
}
}
@@ -464,15 +522,3 @@ func TestLinuxSendfileChild(*testing.T)
panic(err)
}
}
-
-func equal(a, b []byte) bool {
- if len(a) != len(b) {
- return false
- }
- for i := range a {
- if a[i] != b[i] {
- return false
- }
- }
- return true
-}
===================================================================
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"mime"
+ "mime/multipart"
+ "net/textproto"
"os"
"path"
"path/filepath"
@@ -123,8 +125,9 @@ func serveContent(w ResponseWriter, r *R
code := StatusOK
// If Content-Type isn't set, use the file's extension to find it.
- if w.Header().Get("Content-Type") == "" {
- ctype := mime.TypeByExtension(filepath.Ext(name))
+ ctype := w.Header().Get("Content-Type")
+ if ctype == "" {
+ ctype = mime.TypeByExtension(filepath.Ext(name))
if ctype == "" {
// read a chunk to decide between utf-8 text and binary
var buf [1024]byte
@@ -141,18 +144,34 @@ func serveContent(w ResponseWriter, r *R
}
// handle Content-Range header.
- // TODO(adg): handle multiple ranges
sendSize := size
+ var sendContent io.Reader = content
if size >= 0 {
ranges, err := parseRange(r.Header.Get("Range"), size)
- if err == nil && len(ranges) > 1 {
- err = errors.New("multiple ranges not supported")
- }
if err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
- if len(ranges) == 1 {
+ if sumRangesSize(ranges) >= size {
+ // The total number of bytes in all the ranges
+ // is larger the the size of the file by
+ // itself, so this is probably an attack, or a
+ // dumb client. Ignore the range request.
+ ranges = nil
+ }
+ switch {
+ case len(ranges) == 1:
+ // RFC 2616, Section 14.16:
+ // "When an HTTP message includes the content of a single
+ // range (for example, a response to a request for a
+ // single range, or to a request for a set of ranges
+ // that overlap without any holes), this content is
+ // transmitted with a Content-Range header, and a
+ // Content-Length header showing the number of bytes
+ // actually transferred.
+ // ...
+ // A response to a request for a single range MUST NOT
+ // be sent using the multipart/byteranges media type."
ra := ranges[0]
if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
@@ -160,7 +179,41 @@ func serveContent(w ResponseWriter, r *R
}
sendSize = ra.length
code = StatusPartialContent
- w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size))
+ w.Header().Set("Content-Range", ra.contentRange(size))
+ case len(ranges) > 1:
+ for _, ra := range ranges {
+ if ra.start > size {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ }
+ sendSize = rangesMIMESize(ranges, ctype, size)
+ code = StatusPartialContent
+
+ pr, pw := io.Pipe()
+ mw := multipart.NewWriter(pw)
+ w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+ sendContent = pr
+ defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+ go func() {
+ for _, ra := range ranges {
+ part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
+ if err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := io.CopyN(part, content, ra.length); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ }
+ mw.Close()
+ pw.Close()
+ }()
}
w.Header().Set("Accept-Ranges", "bytes")
@@ -172,11 +225,7 @@ func serveContent(w ResponseWriter, r *R
w.WriteHeader(code)
if r.Method != "HEAD" {
- if sendSize == -1 {
- io.Copy(w, content)
- } else {
- io.CopyN(w, content, sendSize)
- }
+ io.CopyN(w, sendContent, sendSize)
}
}
@@ -312,6 +361,17 @@ type httpRange struct {
start, length int64
}
+func (r httpRange) contentRange(size int64) string {
+ return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
+}
+
+func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
+ return textproto.MIMEHeader{
+ "Content-Range": {r.contentRange(size)},
+ "Content-Type": {contentType},
+ }
+}
+
// parseRange parses a Range header string as per RFC 2616.
func parseRange(s string, size int64) ([]httpRange, error) {
if s == "" {
@@ -323,11 +383,15 @@ func parseRange(s string, size int64) ([
}
var ranges []httpRange
for _, ra := range strings.Split(s[len(b):], ",") {
+ ra = strings.TrimSpace(ra)
+ if ra == "" {
+ continue
+ }
i := strings.Index(ra, "-")
if i < 0 {
return nil, errors.New("invalid range")
}
- start, end := ra[:i], ra[i+1:]
+ start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
var r httpRange
if start == "" {
// If no start is specified, end specifies the
@@ -365,3 +429,32 @@ func parseRange(s string, size int64) ([
}
return ranges, nil
}
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+ *w += countingWriter(len(p))
+ return len(p), nil
+}
+
+// rangesMIMESize returns the nunber of bytes it takes to encode the
+// provided ranges as a multipart response.
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
+ var w countingWriter
+ mw := multipart.NewWriter(&w)
+ for _, ra := range ranges {
+ mw.CreatePart(ra.mimeHeader(contentType, contentSize))
+ encSize += ra.length
+ }
+ mw.Close()
+ encSize += int64(w)
+ return
+}
+
+func sumRangesSize(ranges []httpRange) (size int64) {
+ for _, ra := range ranges {
+ size += ra.length
+ }
+ return
+}
===================================================================
@@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w h
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
// of ASN.1 time).
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
+MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
-8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
-DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
-CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
-+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
------END CERTIFICATE-----
-`)
+8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB
+/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC
+CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB
+nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw
+Pg==
+-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
===================================================================
@@ -365,7 +365,18 @@ func (t *Transport) getConn(cm *connectM
if cm.targetScheme == "https" {
// Initiate TLS and check remote host name against certificate.
- conn = tls.Client(conn, t.TLSClientConfig)
+ cfg := t.TLSClientConfig
+ if cfg == nil || cfg.ServerName == "" {
+ host, _, _ := net.SplitHostPort(cm.addr())
+ if cfg == nil {
+ cfg = &tls.Config{ServerName: host}
+ } else {
+ clone := *cfg // shallow clone
+ clone.ServerName = host
+ cfg = &clone
+ }
+ }
+ conn = tls.Client(conn, cfg)
if err = conn.(*tls.Conn).Handshake(); err != nil {
return nil, err
}
===================================================================
@@ -136,7 +136,7 @@ func NewPackage(fset *token.FileSet, fil
for _, obj := range pkg.Data.(*Scope).Objects {
p.declare(fileScope, pkgScope, obj)
}
- } else {
+ } else if name != "_" {
// declare imported package object in file scope
// (do not re-use pkg in the file scope but create
// a new object instead; the Decl field is different