Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/testenv"
20 "io"
21 "log"
22 "math/rand"
23 "mime/multipart"
24 "net"
25 . "net/http"
26 "net/http/httptest"
27 "net/http/httptrace"
28 "net/http/httputil"
29 "net/http/internal"
30 "net/http/internal/testcert"
31 "net/url"
32 "os"
33 "path/filepath"
34 "reflect"
35 "regexp"
36 "runtime"
37 "strconv"
38 "strings"
39 "sync"
40 "sync/atomic"
41 "syscall"
42 "testing"
43 "time"
44 )
45
// dummyAddr is a net.Addr whose network name and string form are both the
// underlying string value.
type dummyAddr string

// oneConnListener is a net.Listener that hands out one pre-set connection:
// the first Accept returns conn, and every later Accept returns io.EOF.
type oneConnListener struct {
	conn net.Conn
}

// Accept returns the stored connection once and clears it. When no
// connection is stored (never set, or already handed out) it returns
// nil, io.EOF.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close is a no-op that always succeeds.
func (l *oneConnListener) Close() error { return nil }

// Addr reports the fixed placeholder address "test-address".
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network returns the address string itself.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address string itself.
func (a dummyAddr) String() string { return string(a) }

// noopConn supplies the address and deadline methods of net.Conn as no-ops,
// so test connection types only need to provide Read, Write and Close.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr              { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr             { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(time.Time) error      { return nil }
func (noopConn) SetReadDeadline(time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(time.Time) error { return nil }

// rwTestConn is a net.Conn whose reads and writes go to the embedded Reader
// and Writer.
type rwTestConn struct {
	io.Reader
	io.Writer
	noopConn

	closeFunc func() error // if non-nil, Close calls it and returns its result
	closec    chan bool    // otherwise, Close tries a non-blocking send of true
}

// Close delegates to closeFunc when one is set. Otherwise it offers true on
// closec without blocking (the signal is dropped if nobody can take it) and
// returns nil.
func (c *rwTestConn) Close() error {
	if fn := c.closeFunc; fn != nil {
		return fn()
	}
	select {
	case c.closec <- true:
	default:
	}
	return nil
}

// testConn is an in-memory net.Conn: Read drains readBuf, Write appends to
// writeBuf, and Close signals on closec.
type testConn struct {
	readMu   sync.Mutex   // held by Read while it drains readBuf
	readBuf  bytes.Buffer // data handed out by Read
	writeBuf bytes.Buffer // data collected by Write
	closec   chan bool    // Close tries a non-blocking send of true here
	noopConn
}

// newTestConn returns a testConn whose closec can hold one pending close
// signal.
func newTestConn() *testConn {
	c := &testConn{}
	c.closec = make(chan bool, 1)
	return c
}

// Read reads from readBuf while holding readMu.
func (c *testConn) Read(b []byte) (int, error) {
	c.readMu.Lock()
	n, err := c.readBuf.Read(b)
	c.readMu.Unlock()
	return n, err
}

// Write appends b to writeBuf.
func (c *testConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }

// Close offers true on closec without blocking and always returns nil.
func (c *testConn) Close() error {
	select {
	case c.closec <- true:
	default:
	}
	return nil
}
135
136
137
// reqBytes converts a request written with "\n" line endings into wire form.
// Surrounding whitespace is trimmed, every "\n" becomes "\r\n", and a
// "\r\n\r\n" terminator is appended.
func reqBytes(req string) []byte {
	wire := strings.ReplaceAll(strings.TrimSpace(req), "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
141
142 type handlerTest struct {
143 logbuf bytes.Buffer
144 handler Handler
145 }
146
147 func newHandlerTest(h Handler) handlerTest {
148 return handlerTest{handler: h}
149 }
150
151 func (ht *handlerTest) rawResponse(req string) string {
152 reqb := reqBytes(req)
153 var output strings.Builder
154 conn := &rwTestConn{
155 Reader: bytes.NewReader(reqb),
156 Writer: &output,
157 closec: make(chan bool, 1),
158 }
159 ln := &oneConnListener{conn: conn}
160 srv := &Server{
161 ErrorLog: log.New(&ht.logbuf, "", 0),
162 Handler: ht.handler,
163 }
164 go srv.Serve(ln)
165 <-conn.closec
166 return output.String()
167 }
168
169 func TestConsumingBodyOnNextConn(t *testing.T) {
170 t.Parallel()
171 defer afterTest(t)
172 conn := new(testConn)
173 for i := 0; i < 2; i++ {
174 conn.readBuf.Write([]byte(
175 "POST / HTTP/1.1\r\n" +
176 "Host: test\r\n" +
177 "Content-Length: 11\r\n" +
178 "\r\n" +
179 "foo=1&bar=1"))
180 }
181
182 reqNum := 0
183 ch := make(chan *Request)
184 servech := make(chan error)
185 listener := &oneConnListener{conn}
186 handler := func(res ResponseWriter, req *Request) {
187 reqNum++
188 ch <- req
189 }
190
191 go func() {
192 servech <- Serve(listener, HandlerFunc(handler))
193 }()
194
195 var req *Request
196 req = <-ch
197 if req == nil {
198 t.Fatal("Got nil first request.")
199 }
200 if req.Method != "POST" {
201 t.Errorf("For request #1's method, got %q; expected %q",
202 req.Method, "POST")
203 }
204
205 req = <-ch
206 if req == nil {
207 t.Fatal("Got nil first request.")
208 }
209 if req.Method != "POST" {
210 t.Errorf("For request #2's method, got %q; expected %q",
211 req.Method, "POST")
212 }
213
214 if serveerr := <-servech; serveerr != io.EOF {
215 t.Errorf("Serve returned %q; expected EOF", serveerr)
216 }
217 }
218
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

// ServeHTTP sets the "Result" header to s.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	result := string(s)
	w.Header().Set("Result", result)
}
224
// handlers lists the ServeMux patterns registered by testHostHandlers. Each
// pattern is served by a stringHandler that reports msg in the "Result"
// response header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}
234
// vtests are the requests sent by testHostHandlers. For a 200 reply,
// expected is compared with the "Result" header. For a 301 reply, it is
// compared with the "Location" header.
var vtests = []struct {
	url      string
	expected string
}{
	// Answered by a matching pattern (Result header).
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Expected to redirect to the slash-terminated path (Location header).
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
250
251 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
252 func testHostHandlers(t *testing.T, mode testMode) {
253 mux := NewServeMux()
254 for _, h := range handlers {
255 mux.Handle(h.pattern, stringHandler(h.msg))
256 }
257 ts := newClientServerTest(t, mode, mux).ts
258
259 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
260 if err != nil {
261 t.Fatal(err)
262 }
263 defer conn.Close()
264 cc := httputil.NewClientConn(conn, nil)
265 for _, vt := range vtests {
266 var r *Response
267 var req Request
268 if req.URL, err = url.Parse(vt.url); err != nil {
269 t.Errorf("cannot parse url: %v", err)
270 continue
271 }
272 if err := cc.Write(&req); err != nil {
273 t.Errorf("writing request: %v", err)
274 continue
275 }
276 r, err := cc.Read(&req)
277 if err != nil {
278 t.Errorf("reading response: %v", err)
279 continue
280 }
281 switch r.StatusCode {
282 case StatusOK:
283 s := r.Header.Get("Result")
284 if s != vt.expected {
285 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
286 }
287 case StatusMovedPermanently:
288 s := r.Header.Get("Location")
289 if s != vt.expected {
290 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
291 }
292 default:
293 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
294 }
295 }
296 }
297
298 var serveMuxRegister = []struct {
299 pattern string
300 h Handler
301 }{
302 {"/dir/", serve(200)},
303 {"/search", serve(201)},
304 {"codesearch.google.com/search", serve(202)},
305 {"codesearch.google.com/", serve(203)},
306 {"example.com/", HandlerFunc(checkQueryStringHandler)},
307 }
308
309
// serve returns a handler that replies with the given status code and
// writes no body.
func serve(code int) HandlerFunc {
	respond := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	}
	return respond
}
315
316
317
318
// checkQueryStringHandler replies 200 when r.URL.RawQuery, prefixed with
// "http://", equals the request URL rebuilt with scheme "http", host r.Host
// and no query. Otherwise it replies 500.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	want := *r.URL
	want.Scheme = "http"
	want.Host = r.Host
	want.RawQuery = ""

	status := 500
	if "http://"+r.URL.RawQuery == want.String() {
		status = 200
	}
	w.WriteHeader(status)
}
330
// serveMuxTests lists requests for TestServeMuxHandler. Each entry gives the
// status the handler picked by a ServeMux built from serveMuxRegister must
// write, and the pattern that mux.Handler must report.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{method: "GET", host: "google.com", path: "/", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir", code: 301, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/dir/", code: 200, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/dir/file", code: 200, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/search", code: 201, pattern: "/search"},
	{method: "GET", host: "google.com", path: "/search/", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/search/foo", code: 404, pattern: ""},
	{method: "GET", host: "codesearch.google.com", path: "/search", code: 202, pattern: "codesearch.google.com/search"},
	{method: "GET", host: "codesearch.google.com", path: "/search/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com", path: "/search/foo", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com", path: "/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com:443", path: "/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "images.google.com", path: "/search", code: 201, pattern: "/search"},
	{method: "GET", host: "images.google.com", path: "/search/", code: 404, pattern: ""},
	{method: "GET", host: "images.google.com", path: "/search/foo", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/../search", code: 301, pattern: "/search"},
	{method: "GET", host: "google.com", path: "/dir/..", code: 301, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir/..", code: 301, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir/./file", code: 301, pattern: "/dir/"},

	// CONNECT requests still get the trailing-slash redirect (/dir -> /dir/)
	// but, per the expectations below, their paths are not cleaned.
	{method: "CONNECT", host: "google.com", path: "/dir", code: 301, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/../search", code: 404, pattern: ""},
	{method: "CONNECT", host: "google.com", path: "/dir/..", code: 200, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/dir/..", code: 200, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/dir/./file", code: 200, pattern: "/dir/"},
}
366
367 func TestServeMuxHandler(t *testing.T) {
368 setParallel(t)
369 mux := NewServeMux()
370 for _, e := range serveMuxRegister {
371 mux.Handle(e.pattern, e.h)
372 }
373
374 for _, tt := range serveMuxTests {
375 r := &Request{
376 Method: tt.method,
377 Host: tt.host,
378 URL: &url.URL{
379 Path: tt.path,
380 },
381 }
382 h, pattern := mux.Handler(r)
383 rr := httptest.NewRecorder()
384 h.ServeHTTP(rr, r)
385 if pattern != tt.pattern || rr.Code != tt.code {
386 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
387 }
388 }
389 }
390
391
392 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
393 setParallel(t)
394 defer func() {
395 if err := recover(); err == nil {
396 t.Error("expected call to mux.HandleFunc to panic")
397 }
398 }()
399 mux := NewServeMux()
400 mux.HandleFunc("/", nil)
401 }
402
// serveMuxTests2 drives TestServeMuxHandlerRedirects. Each request is served
// by a ServeMux built from serveMuxRegister and must finish with status
// code. redirOk reports whether a 301 redirect may be followed on the way.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
414
415
416
417 func TestServeMuxHandlerRedirects(t *testing.T) {
418 setParallel(t)
419 mux := NewServeMux()
420 for _, e := range serveMuxRegister {
421 mux.Handle(e.pattern, e.h)
422 }
423
424 for _, tt := range serveMuxTests2 {
425 tries := 1
426 turl := tt.url
427 for {
428 u, e := url.Parse(turl)
429 if e != nil {
430 t.Fatal(e)
431 }
432 r := &Request{
433 Method: tt.method,
434 Host: tt.host,
435 URL: u,
436 }
437 h, _ := mux.Handler(r)
438 rr := httptest.NewRecorder()
439 h.ServeHTTP(rr, r)
440 if rr.Code != 301 {
441 if rr.Code != tt.code {
442 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
443 }
444 break
445 }
446 if !tt.redirOk {
447 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
448 break
449 }
450 turl = rr.HeaderMap.Get("Location")
451 tries--
452 }
453 if tries < 0 {
454 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
455 }
456 }
457 }
458
459
460 func TestMuxRedirectLeadingSlashes(t *testing.T) {
461 setParallel(t)
462 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
463 for _, path := range paths {
464 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
465 if err != nil {
466 t.Errorf("%s", err)
467 }
468 mux := NewServeMux()
469 resp := httptest.NewRecorder()
470
471 mux.ServeHTTP(resp, req)
472
473 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
474 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
475 return
476 }
477
478 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
479 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
480 return
481 }
482 }
483 }
484
485
486
487
488
489 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
490 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
491 }
492 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
493 writeBackQuery := func(w ResponseWriter, r *Request) {
494 fmt.Fprintf(w, "%s", r.URL.RawQuery)
495 }
496
497 mux := NewServeMux()
498 mux.HandleFunc("/testOne", writeBackQuery)
499 mux.HandleFunc("/testTwo/", writeBackQuery)
500 mux.HandleFunc("/testThree", writeBackQuery)
501 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
502 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
503 })
504
505 ts := newClientServerTest(t, mode, mux).ts
506
507 tests := [...]struct {
508 path string
509 method string
510 want string
511 statusOk bool
512 }{
513 0: {"/testOne?this=that", "GET", "this=that", true},
514 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
515 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
516 3: {"/testTwo?", "GET", "", true},
517 4: {"/testThree?foo", "GET", "foo", true},
518 5: {"/testThree/?foo", "GET", "foo:bar", true},
519 6: {"/testThree?foo", "CONNECT", "foo", true},
520 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
521
522
523 8: {"/testOne/foo/..?foo", "GET", "foo", true},
524 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
525 }
526
527 for i, tt := range tests {
528 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
529 res, err := ts.Client().Do(req)
530 if err != nil {
531 continue
532 }
533 slurp, _ := io.ReadAll(res.Body)
534 res.Body.Close()
535 if !tt.statusOk {
536 if got, want := res.StatusCode, 404; got != want {
537 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
538 }
539 }
540 if got, want := string(slurp), tt.want; got != want {
541 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
542 }
543 }
544 }
545
546 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
547 setParallel(t)
548
549 mux := NewServeMux()
550 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
551 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
552 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
553 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
554 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
555 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
556
557 tests := []struct {
558 method string
559 url string
560 code int
561 loc string
562 want string
563 }{
564 {"GET", "http://example.com/", 404, "", ""},
565 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
566 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
567 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
568 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
569 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
570 {"CONNECT", "http://example.com/", 404, "", ""},
571 {"CONNECT", "http://example.com:3000/", 404, "", ""},
572 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
573 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
574 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
575 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
576 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
577 }
578
579 for i, tt := range tests {
580 req, _ := NewRequest(tt.method, tt.url, nil)
581 w := httptest.NewRecorder()
582 mux.ServeHTTP(w, req)
583
584 if got, want := w.Code, tt.code; got != want {
585 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
586 }
587
588 if tt.code == 301 {
589 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
590 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
591 }
592 } else {
593 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
594 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
595 }
596 }
597 }
598 }
599
600 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
601 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
602 mux := NewServeMux()
603 newClientServerTest(t, mode, mux)
604 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
605 }
606
607 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
608 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
609 func benchmarkServeMux(b *testing.B, runHandler bool) {
610 type test struct {
611 path string
612 code int
613 req *Request
614 }
615
616
617 var tests []test
618 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
619 for _, e := range endpoints {
620 for i := 200; i < 230; i++ {
621 p := fmt.Sprintf("/%s/%d/", e, i)
622 tests = append(tests, test{
623 path: p,
624 code: i,
625 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
626 })
627 }
628 }
629 mux := NewServeMux()
630 for _, tt := range tests {
631 mux.Handle(tt.path, serve(tt.code))
632 }
633
634 rw := httptest.NewRecorder()
635 b.ReportAllocs()
636 b.ResetTimer()
637 for i := 0; i < b.N; i++ {
638 for _, tt := range tests {
639 *rw = httptest.ResponseRecorder{}
640 h, pattern := mux.Handler(tt.req)
641 if runHandler {
642 h.ServeHTTP(rw, tt.req)
643 if pattern != tt.path || rw.Code != tt.code {
644 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
645 }
646 }
647 }
648 }
649 }
650
651 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
652 func testServerTimeouts(t *testing.T, mode testMode) {
653 runTimeSensitiveTest(t, []time.Duration{
654 10 * time.Millisecond,
655 50 * time.Millisecond,
656 100 * time.Millisecond,
657 500 * time.Millisecond,
658 1 * time.Second,
659 }, func(t *testing.T, timeout time.Duration) error {
660 return testServerTimeoutsWithTimeout(t, timeout, mode)
661 })
662 }
663
664 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
665 var reqNum atomic.Int32
666 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
667 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
668 }), func(ts *httptest.Server) {
669 ts.Config.ReadTimeout = timeout
670 ts.Config.WriteTimeout = timeout
671 })
672 defer cst.close()
673 ts := cst.ts
674
675
676 c := ts.Client()
677 r, err := c.Get(ts.URL)
678 if err != nil {
679 return fmt.Errorf("http Get #1: %v", err)
680 }
681 got, err := io.ReadAll(r.Body)
682 expected := "req=1"
683 if string(got) != expected || err != nil {
684 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
685 string(got), err, expected)
686 }
687
688
689 t1 := time.Now()
690 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
691 if err != nil {
692 return fmt.Errorf("Dial: %v", err)
693 }
694 buf := make([]byte, 1)
695 n, err := conn.Read(buf)
696 conn.Close()
697 latency := time.Since(t1)
698 if n != 0 || err != io.EOF {
699 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
700 }
701 minLatency := timeout / 5 * 4
702 if latency < minLatency {
703 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
704 }
705
706
707
708
709 r, err = c.Get(ts.URL)
710 if err != nil {
711 return fmt.Errorf("http Get #2: %v", err)
712 }
713 got, err = io.ReadAll(r.Body)
714 r.Body.Close()
715 expected = "req=2"
716 if string(got) != expected || err != nil {
717 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
718 }
719
720 if !testing.Short() {
721 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
722 if err != nil {
723 return fmt.Errorf("long Dial: %v", err)
724 }
725 defer conn.Close()
726 go io.Copy(io.Discard, conn)
727 for i := 0; i < 5; i++ {
728 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
729 if err != nil {
730 return fmt.Errorf("on write %d: %v", i, err)
731 }
732 time.Sleep(timeout / 2)
733 }
734 }
735 return nil
736 }
737
738 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
739 func testServerReadTimeout(t *testing.T, mode testMode) {
740 respBody := "response body"
741 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
742 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
743 _, err := io.Copy(io.Discard, req.Body)
744 if !errors.Is(err, os.ErrDeadlineExceeded) {
745 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
746 }
747 res.Write([]byte(respBody))
748 }), func(ts *httptest.Server) {
749 ts.Config.ReadHeaderTimeout = -1
750 ts.Config.ReadTimeout = timeout
751 })
752 pr, pw := io.Pipe()
753 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
754 if err != nil {
755 t.Logf("Get error, retrying: %v", err)
756 cst.close()
757 continue
758 }
759 defer res.Body.Close()
760 got, err := io.ReadAll(res.Body)
761 if string(got) != respBody || err != nil {
762 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
763 }
764 pw.Close()
765 break
766 }
767 }
768
769 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
770 func testServerWriteTimeout(t *testing.T, mode testMode) {
771 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
772 errc := make(chan error, 2)
773 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
774 errc <- nil
775 _, err := io.Copy(res, neverEnding('a'))
776 errc <- err
777 }), func(ts *httptest.Server) {
778 ts.Config.WriteTimeout = timeout
779 })
780 res, err := cst.c.Get(cst.ts.URL)
781 if err != nil {
782
783 t.Logf("Get error, retrying: %v", err)
784 cst.close()
785 continue
786 }
787 defer res.Body.Close()
788 _, err = io.Copy(io.Discard, res.Body)
789 if err == nil {
790 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
791 }
792 select {
793 case <-errc:
794 err = <-errc
795 if !errors.Is(err, os.ErrDeadlineExceeded) {
796 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
797 }
798 return
799 default:
800
801 t.Logf("handler didn't run, retrying")
802 cst.close()
803 }
804 }
805 }
806
807
808 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
809 run(t, testWriteDeadlineExtendedOnNewRequest)
810 }
811 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
812 if testing.Short() {
813 t.Skip("skipping in short mode")
814 }
815 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
816 func(ts *httptest.Server) {
817 ts.Config.WriteTimeout = 250 * time.Millisecond
818 },
819 ).ts
820
821 c := ts.Client()
822
823 for i := 1; i <= 3; i++ {
824 req, err := NewRequest("GET", ts.URL, nil)
825 if err != nil {
826 t.Fatal(err)
827 }
828
829 r, err := c.Do(req)
830 if err != nil {
831 t.Fatalf("http2 Get #%d: %v", i, err)
832 }
833 r.Body.Close()
834 time.Sleep(ts.Config.WriteTimeout / 2)
835 }
836 }
837
838
839
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, stopping at
// the first call that returns nil. Each failure is logged. If every attempt
// fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := range tries {
		err := testFunc(tries[i])
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", tries[i], err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
854
855
856 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
857 if testing.Short() {
858 t.Skip("skipping in short mode")
859 }
860 setParallel(t)
861 run(t, func(t *testing.T, mode testMode) {
862 tryTimeouts(t, func(timeout time.Duration) error {
863 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
864 })
865 })
866 }
867
868 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
869 firstRequest := make(chan bool, 1)
870 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
871 select {
872 case firstRequest <- true:
873
874 default:
875
876 time.Sleep(timeout)
877 }
878 }), func(ts *httptest.Server) {
879 ts.Config.WriteTimeout = timeout / 2
880 })
881 defer cst.close()
882 ts := cst.ts
883
884 c := ts.Client()
885
886 req, err := NewRequest("GET", ts.URL, nil)
887 if err != nil {
888 return fmt.Errorf("NewRequest: %v", err)
889 }
890 r, err := c.Do(req)
891 if err != nil {
892 return fmt.Errorf("Get #1: %v", err)
893 }
894 r.Body.Close()
895
896 req, err = NewRequest("GET", ts.URL, nil)
897 if err != nil {
898 return fmt.Errorf("NewRequest: %v", err)
899 }
900 r, err = c.Do(req)
901 if err == nil {
902 r.Body.Close()
903 return fmt.Errorf("Get #2 expected error, got nil")
904 }
905 if mode == http2Mode {
906 expected := "stream ID 3; INTERNAL_ERROR"
907 if !strings.Contains(err.Error(), expected) {
908 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
909 }
910 }
911 return nil
912 }
913
914
915 func TestNoWriteDeadline(t *testing.T) {
916 if testing.Short() {
917 t.Skip("skipping in short mode")
918 }
919 setParallel(t)
920 defer afterTest(t)
921 run(t, func(t *testing.T, mode testMode) {
922 tryTimeouts(t, func(timeout time.Duration) error {
923 return testNoWriteDeadline(t, mode, timeout)
924 })
925 })
926 }
927
928 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
929 firstRequest := make(chan bool, 1)
930 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
931 select {
932 case firstRequest <- true:
933
934 default:
935
936 time.Sleep(timeout)
937 }
938 }))
939 defer cst.close()
940 ts := cst.ts
941
942 c := ts.Client()
943
944 for i := 0; i < 2; i++ {
945 req, err := NewRequest("GET", ts.URL, nil)
946 if err != nil {
947 return fmt.Errorf("NewRequest: %v", err)
948 }
949 r, err := c.Do(req)
950 if err != nil {
951 return fmt.Errorf("Get #%d: %v", i, err)
952 }
953 r.Body.Close()
954 }
955 return nil
956 }
957
958
959
960
961 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
962 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
963 var (
964 mu sync.RWMutex
965 conn net.Conn
966 )
967 var afterTimeoutErrc = make(chan error, 1)
968 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
969 buf := make([]byte, 512<<10)
970 _, err := w.Write(buf)
971 if err != nil {
972 t.Errorf("handler Write error: %v", err)
973 return
974 }
975 mu.RLock()
976 defer mu.RUnlock()
977 if conn == nil {
978 t.Error("no established connection found")
979 return
980 }
981 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
982 _, err = w.Write(buf)
983 afterTimeoutErrc <- err
984 }), func(ts *httptest.Server) {
985 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
986 }).ts
987
988 c := ts.Client()
989
990 err := func() error {
991 res, err := c.Get(ts.URL)
992 if err != nil {
993 return err
994 }
995 _, err = io.Copy(io.Discard, res.Body)
996 res.Body.Close()
997 return err
998 }()
999 if err == nil {
1000 t.Errorf("expected an error copying body from Get request")
1001 }
1002
1003 if err := <-afterTimeoutErrc; err == nil {
1004 t.Error("expected write error after timeout")
1005 }
1006 }
1007
1008
// trackLastConnListener wraps a net.Listener. Under mu, it stores each
// successfully accepted connection into *last, so *last always holds the
// most recently accepted connection.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex
	last *net.Conn
}

// Accept delegates to the wrapped listener and records the connection on
// success. Errors are returned unchanged and leave *last untouched.
func (l trackLastConnListener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1025
1026
1027 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1028 func testIdentityResponse(t *testing.T, mode testMode) {
1029 if mode == http2Mode {
1030 t.Skip("https://go.dev/issue/56019")
1031 }
1032
1033 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1034 rw.Header().Set("Content-Length", "3")
1035 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1036 switch {
1037 case req.FormValue("overwrite") == "1":
1038 _, err := rw.Write([]byte("foo TOO LONG"))
1039 if err != ErrContentLength {
1040 t.Errorf("expected ErrContentLength; got %v", err)
1041 }
1042 case req.FormValue("underwrite") == "1":
1043 rw.Header().Set("Content-Length", "500")
1044 rw.Write([]byte("too short"))
1045 default:
1046 rw.Write([]byte("foo"))
1047 }
1048 })
1049
1050 ts := newClientServerTest(t, mode, handler).ts
1051 c := ts.Client()
1052
1053
1054
1055
1056
1057 for _, te := range []string{"", "identity"} {
1058 url := ts.URL + "/?te=" + te
1059 res, err := c.Get(url)
1060 if err != nil {
1061 t.Fatalf("error with Get of %s: %v", url, err)
1062 }
1063 if cl, expected := res.ContentLength, int64(3); cl != expected {
1064 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1065 }
1066 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1067 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1068 }
1069 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1070 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1071 url, expected, tl, res.TransferEncoding)
1072 }
1073 res.Body.Close()
1074 }
1075
1076
1077 url := ts.URL + "/?overwrite=1"
1078 res, err := c.Get(url)
1079 if err != nil {
1080 t.Fatalf("error with Get of %s: %v", url, err)
1081 }
1082 res.Body.Close()
1083
1084 if mode != http1Mode {
1085 return
1086 }
1087
1088
1089
1090 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1091 if err != nil {
1092 t.Fatalf("error dialing: %v", err)
1093 }
1094 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1095 if err != nil {
1096 t.Fatalf("error writing: %v", err)
1097 }
1098
1099
1100 got, _ := io.ReadAll(conn)
1101 expectedSuffix := "\r\n\r\ntoo short"
1102 if !strings.HasSuffix(string(got), expectedSuffix) {
1103 t.Errorf("Expected output to end with %q; got response body %q",
1104 expectedSuffix, string(got))
1105 }
1106 }
1107
1108 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1109 setParallel(t)
1110 s := newClientServerTest(t, http1Mode, h).ts
1111
1112 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1113 if err != nil {
1114 t.Fatal("dial error:", err)
1115 }
1116 defer conn.Close()
1117
1118 _, err = fmt.Fprint(conn, req)
1119 if err != nil {
1120 t.Fatal("print error:", err)
1121 }
1122
1123 r := bufio.NewReader(conn)
1124 res, err := ReadResponse(r, &Request{Method: "GET"})
1125 if err != nil {
1126 t.Fatal("ReadResponse error:", err)
1127 }
1128
1129 _, err = io.ReadAll(r)
1130 if err != nil {
1131 t.Fatal("read error:", err)
1132 }
1133
1134 if !res.Close {
1135 t.Errorf("Response.Close = false; want true")
1136 }
1137 }
1138
1139 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1140 setParallel(t)
1141 ts := newClientServerTest(t, http1Mode, handler).ts
1142 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1143 if err != nil {
1144 t.Fatal(err)
1145 }
1146 defer conn.Close()
1147 br := bufio.NewReader(conn)
1148 for i := 0; i < 2; i++ {
1149 if _, err := io.WriteString(conn, req); err != nil {
1150 t.Fatal(err)
1151 }
1152 res, err := ReadResponse(br, nil)
1153 if err != nil {
1154 t.Fatalf("res %d: %v", i+1, err)
1155 }
1156 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1157 t.Fatalf("res %d body copy: %v", i+1, err)
1158 }
1159 res.Body.Close()
1160 }
1161 }
1162
1163
1164 func TestServeHTTP10Close(t *testing.T) {
1165 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1166 ServeFile(w, r, "testdata/file")
1167 }))
1168 }
1169
1170
1171 func TestClientCanClose(t *testing.T) {
1172 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1173
1174 }))
1175 }
1176
1177
1178
1179 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1180 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1181 w.Header().Set("Connection", "close")
1182 }))
1183 }
1184
1185 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1186 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1187 w.Header().Set("Connection", "close")
1188 }))
1189 }
1190
1191 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1192 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1193
1194
1195 }))
1196 }
1197
// send204 replies with status 204 (No Content) and writes no body.
func send204(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNoContent)
}

// send304 replies with status 304 (Not Modified) and writes no body.
func send304(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNotModified)
}
1200
1201
1202 func TestHTTP10KeepAlive204Response(t *testing.T) {
1203 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1204 }
1205
1206 func TestHTTP11KeepAlive204Response(t *testing.T) {
1207 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1208 }
1209
1210 func TestHTTP10KeepAlive304Response(t *testing.T) {
1211 testTCPConnectionStaysOpen(t,
1212 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1213 HandlerFunc(send304))
1214 }
1215
1216
1217 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1218 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1219 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1220 w.(Flusher).Flush()
1221 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1222 }))
1223 type data struct {
1224 Addr string
1225 }
1226 var addrs [2]data
1227 for i := range addrs {
1228 res, err := cst.c.Get(cst.ts.URL)
1229 if err != nil {
1230 t.Fatal(err)
1231 }
1232 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1233 t.Fatal(err)
1234 }
1235 if addrs[i].Addr == "" {
1236 t.Fatal("no address")
1237 }
1238 res.Body.Close()
1239 }
1240 if addrs[0] != addrs[1] {
1241 t.Fatalf("connection not reused")
1242 }
1243 }
1244
1245 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1246 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1247 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1248 fmt.Fprintf(w, "%s", r.RemoteAddr)
1249 }))
1250
1251 res, err := cst.c.Get(cst.ts.URL)
1252 if err != nil {
1253 t.Fatalf("Get error: %v", err)
1254 }
1255 body, err := io.ReadAll(res.Body)
1256 if err != nil {
1257 t.Fatalf("ReadAll error: %v", err)
1258 }
1259 ip := string(body)
1260 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1261 t.Fatalf("Expected local addr; got %q", ip)
1262 }
1263 }
1264
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection is returned as a *blockingRemoteAddrConn. Each wrapped connection
// is also sent on conns before Accept returns, so a test can later decide what
// its RemoteAddr reports.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

// Accept accepts a connection from the wrapped listener and wraps it. It then
// publishes the wrapper on l.conns, which blocks until the channel accepts the
// value, and returns it.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until an address
// is delivered on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr waits for the next address sent on c.addrs and returns it.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	return <-c.addrs
}
1291
1292
1293 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1294 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1295 }
1296 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1297 conns := make(chan net.Conn)
1298 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1299 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1300 }), func(ts *httptest.Server) {
1301 ts.Listener = &blockingRemoteAddrListener{
1302 Listener: ts.Listener,
1303 conns: conns,
1304 }
1305 }).ts
1306
1307 c := ts.Client()
1308
1309 c.Transport.(*Transport).DisableKeepAlives = true
1310
1311 fetch := func(num int, response chan<- string) {
1312 resp, err := c.Get(ts.URL)
1313 if err != nil {
1314 t.Errorf("Request %d: %v", num, err)
1315 response <- ""
1316 return
1317 }
1318 defer resp.Body.Close()
1319 body, err := io.ReadAll(resp.Body)
1320 if err != nil {
1321 t.Errorf("Request %d: %v", num, err)
1322 response <- ""
1323 return
1324 }
1325 response <- string(body)
1326 }
1327
1328
1329 response1c := make(chan string, 1)
1330 go fetch(1, response1c)
1331
1332
1333 conn1 := <-conns
1334
1335
1336 response2c := make(chan string, 1)
1337 go fetch(2, response2c)
1338 conn2 := <-conns
1339
1340
1341 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1342 IP: net.ParseIP("12.12.12.12"), Port: 12}
1343
1344
1345 response2 := <-response2c
1346 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1347 t.Fatalf("response 2 addr = %q; want %q", g, e)
1348 }
1349
1350
1351 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1352 IP: net.ParseIP("21.21.21.21"), Port: 21}
1353
1354
1355 response1 := <-response1c
1356 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1357 t.Fatalf("response 1 addr = %q; want %q", g, e)
1358 }
1359 }
1360
1361
1362
1363 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1364 func testHeadResponses(t *testing.T, mode testMode) {
1365 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1366 _, err := w.Write([]byte("<html>"))
1367 if err != nil {
1368 t.Errorf("ResponseWriter.Write: %v", err)
1369 }
1370
1371
1372 _, err = io.Copy(w, strings.NewReader("789a"))
1373 if err != nil {
1374 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1375 }
1376 }))
1377 res, err := cst.c.Head(cst.ts.URL)
1378 if err != nil {
1379 t.Error(err)
1380 }
1381 if len(res.TransferEncoding) > 0 {
1382 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1383 }
1384 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1385 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1386 }
1387 if v := res.ContentLength; v != 10 {
1388 t.Errorf("Content-Length: %d; want 10", v)
1389 }
1390 body, err := io.ReadAll(res.Body)
1391 if err != nil {
1392 t.Error(err)
1393 }
1394 if len(body) > 0 {
1395 t.Errorf("got unexpected body %q", string(body))
1396 }
1397 }
1398
1399 func TestTLSHandshakeTimeout(t *testing.T) {
1400 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1401 }
1402 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1403 errLog := new(strings.Builder)
1404 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1405 func(ts *httptest.Server) {
1406 ts.Config.ReadTimeout = 250 * time.Millisecond
1407 ts.Config.ErrorLog = log.New(errLog, "", 0)
1408 },
1409 )
1410 ts := cst.ts
1411
1412 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1413 if err != nil {
1414 t.Fatalf("Dial: %v", err)
1415 }
1416 var buf [1]byte
1417 n, err := conn.Read(buf[:])
1418 if err == nil || n != 0 {
1419 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1420 }
1421 conn.Close()
1422
1423 cst.close()
1424 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1425 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1426 }
1427 }
1428
1429 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1430 func testTLSServer(t *testing.T, mode testMode) {
1431 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1432 if r.TLS != nil {
1433 w.Header().Set("X-TLS-Set", "true")
1434 if r.TLS.HandshakeComplete {
1435 w.Header().Set("X-TLS-HandshakeComplete", "true")
1436 }
1437 }
1438 }), func(ts *httptest.Server) {
1439 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1440 }).ts
1441
1442
1443
1444
1445
1446
1447 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1448 if err != nil {
1449 t.Fatalf("Dial: %v", err)
1450 }
1451 defer idleConn.Close()
1452
1453 if !strings.HasPrefix(ts.URL, "https://") {
1454 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1455 return
1456 }
1457 client := ts.Client()
1458 res, err := client.Get(ts.URL)
1459 if err != nil {
1460 t.Error(err)
1461 return
1462 }
1463 if res == nil {
1464 t.Errorf("got nil Response")
1465 return
1466 }
1467 defer res.Body.Close()
1468 if res.Header.Get("X-TLS-Set") != "true" {
1469 t.Errorf("expected X-TLS-Set response header")
1470 return
1471 }
1472 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1473 t.Errorf("expected X-TLS-HandshakeComplete header")
1474 }
1475 }
1476
1477 func TestServeTLS(t *testing.T) {
1478 CondSkipHTTP2(t)
1479
1480 defer afterTest(t)
1481 defer SetTestHookServerServe(nil)
1482
1483 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1484 if err != nil {
1485 t.Fatal(err)
1486 }
1487 tlsConf := &tls.Config{
1488 Certificates: []tls.Certificate{cert},
1489 }
1490
1491 ln := newLocalListener(t)
1492 defer ln.Close()
1493 addr := ln.Addr().String()
1494
1495 serving := make(chan bool, 1)
1496 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1497 serving <- true
1498 })
1499 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1500 s := &Server{
1501 Addr: addr,
1502 TLSConfig: tlsConf,
1503 Handler: handler,
1504 }
1505 errc := make(chan error, 1)
1506 go func() { errc <- s.ServeTLS(ln, "", "") }()
1507 select {
1508 case err := <-errc:
1509 t.Fatalf("ServeTLS: %v", err)
1510 case <-serving:
1511 }
1512
1513 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1514 InsecureSkipVerify: true,
1515 NextProtos: []string{"h2", "http/1.1"},
1516 })
1517 if err != nil {
1518 t.Fatal(err)
1519 }
1520 defer c.Close()
1521 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1522 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1523 }
1524 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1525 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1526 }
1527 }
1528
1529
1530 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1531 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1532 }
1533 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1534 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1535 t.Error("unexpected HTTPS request")
1536 }), func(ts *httptest.Server) {
1537 var errBuf bytes.Buffer
1538 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1539 }).ts
1540 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1541 if err != nil {
1542 t.Fatal(err)
1543 }
1544 defer conn.Close()
1545 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1546 slurp, err := io.ReadAll(conn)
1547 if err != nil {
1548 t.Fatal(err)
1549 }
1550 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1551 if !strings.HasPrefix(string(slurp), wantPrefix) {
1552 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1553 }
1554 }
1555
1556
1557 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1558 testAutomaticHTTP2_Serve(t, nil, true)
1559 }
1560
1561 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1562 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1563 }
1564
1565 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1566 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1567 }
1568
1569 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1570 setParallel(t)
1571 defer afterTest(t)
1572 ln := newLocalListener(t)
1573 ln.Close()
1574 var s Server
1575 s.TLSConfig = tlsConf
1576 if err := s.Serve(ln); err == nil {
1577 t.Fatal("expected an error")
1578 }
1579 gotH2 := s.TLSNextProto["h2"] != nil
1580 if gotH2 != wantH2 {
1581 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1582 }
1583 }
1584
1585 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1586 setParallel(t)
1587 defer afterTest(t)
1588 ln := newLocalListener(t)
1589 ln.Close()
1590 var s Server
1591
1592
1593 s.TLSConfig = &tls.Config{
1594 NextProtos: []string{"h2"},
1595 }
1596 if err := s.Serve(ln); err == nil {
1597 t.Fatal("expected an error")
1598 }
1599 on := s.TLSNextProto["h2"] != nil
1600 if !on {
1601 t.Errorf("http2 wasn't automatically enabled")
1602 }
1603 }
1604
1605 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1606 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1607 if err != nil {
1608 t.Fatal(err)
1609 }
1610 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1611 Certificates: []tls.Certificate{cert},
1612 })
1613 }
1614
1615 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1616 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1617 if err != nil {
1618 t.Fatal(err)
1619 }
1620 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1621 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1622 return &cert, nil
1623 },
1624 })
1625 }
1626
// testAutomaticHTTP2_ListenAndServe starts a Server with ListenAndServeTLS
// using tlsConf. The cert/key file arguments are empty, so the certificate must
// come from tlsConf. It then checks that a TLS client offering "h2" and
// "http/1.1" negotiates "h2", i.e. that HTTP/2 was enabled automatically.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
	// ListenAndServeTLS binds its own listener on s.Addr. The test picks a
	// free address by opening and immediately closing a local listener.
	// Presumably something else can grab that port in between, hence the
	// retries.
Try:
	for try := 0; try < maxTries; try++ {
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		lnc := make(chan net.Listener, 1)
		// The test hook receives the listener the server serves on; it is
		// presumably invoked once Serve starts.
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			// Startup failed; log it and try again with a fresh address.
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			// ln now refers to the server's own listener.
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	// Closing the server's listener at the end presumably makes
	// ListenAndServeTLS return.
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1681
// serverExpectTest describes one request sent by testServerExpect and the
// status text expected in the first response line.
type serverExpectTest struct {
	contentLength    int    // body length; also sent as Content-Length unless chunked
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // passed as ?readbody=; tells the handler to read the body
	expectedResponse string // substring the first response line must contain
}

// expectTest builds a serverExpectTest that uses Content-Length framing (not chunked).
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tt serverExpectTest
	tt.contentLength = contentLength
	tt.expectation = expectation
	tt.readBody = readBody
	tt.expectedResponse = expectedResponse
	return tt
}
1698
// serverExpectTests are the cases driven by testServerExpect. Each sends a POST
// with the given body framing and Expect header and checks the first response
// line. With readBody=true the test handler reads the body and replies 200
// "Hi"; otherwise it replies 401 without reading the body.
var serverExpectTests = []serverExpectTest{
	// Normal 100-continues; the expectation value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect header, so no interim "100 Continue": the first line is the
	// final 200.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler replies 401 without reading the body, so
	// the first status line seen is the 401 rather than "100 Continue".
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// The same rejection without an Expect header (the body is sent eagerly).
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognized expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// Expect: 100-continue with Content-Length: 0 gets the final response
	// directly, with no interim "100 Continue".
	expectTest(0, "100-continue", true, "200 OK"),
	// Likewise when the handler rejects the request.
	expectTest(0, "100-continue", false, "401 Unauthorized"),

	// A chunked body with Expect: 100-continue gets an interim "100 Continue".
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1728
1729
1730
// TestServerExpect checks the HTTP/1 server's handling of the Expect request
// header for each case in serverExpectTests, using raw TCP connections.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The branch is chosen from the raw query rather than r.FormValue,
		// presumably because FormValue on a POST would parse (and so read)
		// the body, which must only happen in the readbody=true case.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends the request described by test on a fresh connection. It
	// checks that the first response line contains test.expectedResponse.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Write the body immediately only when there is one and the simulated
		// client is not waiting for "100 Continue" first.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// The request is written from a separate goroutine while this
		// goroutine reads the response. The deferred wg.Wait (which runs
		// before conn.Close) keeps runTest from returning while the writer is
		// still active.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// By default, write the body straight to conn, with a no-op
				// Close. Chunked cases wrap conn in a chunked writer, whose
				// Close ends the chunked body.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler does not read the body, so the server
						// may reply and drop the connection before the body is
						// fully written. Such errors are tolerated.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// As above, the server may close the connection while the
				// client is still sending a body the handler never reads. A
				// read error is tolerated in that case (presumably the peer
				// can see a reset before the response).
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
1826
1827
1828
1829 func TestServerUnreadRequestBodyLittle(t *testing.T) {
1830 setParallel(t)
1831 defer afterTest(t)
1832 conn := new(testConn)
1833 body := strings.Repeat("x", 100<<10)
1834 conn.readBuf.Write([]byte(fmt.Sprintf(
1835 "POST / HTTP/1.1\r\n"+
1836 "Host: test\r\n"+
1837 "Content-Length: %d\r\n"+
1838 "\r\n", len(body))))
1839 conn.readBuf.Write([]byte(body))
1840
1841 done := make(chan bool)
1842
1843 readBufLen := func() int {
1844 conn.readMu.Lock()
1845 defer conn.readMu.Unlock()
1846 return conn.readBuf.Len()
1847 }
1848
1849 ls := &oneConnListener{conn}
1850 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
1851 defer close(done)
1852 if bufLen := readBufLen(); bufLen < len(body)/2 {
1853 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
1854 }
1855 rw.WriteHeader(200)
1856 rw.(Flusher).Flush()
1857 if g, e := readBufLen(), 0; g != e {
1858 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
1859 }
1860 if c := rw.Header().Get("Connection"); c != "" {
1861 t.Errorf(`Connection header = %q; want ""`, c)
1862 }
1863 }))
1864 <-done
1865 }
1866
1867
1868
1869
1870 func TestServerUnreadRequestBodyLarge(t *testing.T) {
1871 setParallel(t)
1872 if testing.Short() && testenv.Builder() == "" {
1873 t.Log("skipping in short mode")
1874 }
1875 conn := new(testConn)
1876 body := strings.Repeat("x", 1<<20)
1877 conn.readBuf.Write([]byte(fmt.Sprintf(
1878 "POST / HTTP/1.1\r\n"+
1879 "Host: test\r\n"+
1880 "Content-Length: %d\r\n"+
1881 "\r\n", len(body))))
1882 conn.readBuf.Write([]byte(body))
1883 conn.closec = make(chan bool, 1)
1884
1885 ls := &oneConnListener{conn}
1886 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
1887 if conn.readBuf.Len() < len(body)/2 {
1888 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
1889 }
1890 rw.WriteHeader(200)
1891 rw.(Flusher).Flush()
1892 if conn.readBuf.Len() < len(body)/2 {
1893 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
1894 }
1895 }))
1896 <-conn.closec
1897
1898 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
1899 t.Errorf("Expected a Connection: close header; got response: %s", res)
1900 }
1901 }
1902
// handlerBodyCloseTest describes one case for testHandlerBodyClose.
type handlerBodyCloseTest struct {
	bodySize     int  // size of the POST body in bytes
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // send "Connection: close" (no follow-up request is sent)

	wantEOFSearch bool // closing the unread body should consume more input
	wantNextReq   bool // the follow-up GET should also be served
}

// connectionHeader returns the "Connection: close" header line when the case
// asks the server to close, and "" otherwise.
func (tt handlerBodyCloseTest) connectionHeader() string {
	header := ""
	if tt.reqConnClose {
		header = "Connection: close\r\n"
	}
	return header
}
1918
// handlerBodyCloseTests are the cases run by TestHandlerBodyClose. In each
// case the handler closes a request body it has not read. wantEOFSearch says
// whether that Close should consume further input from the connection.
// wantNextReq says whether a pipelined follow-up request should be served.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small (20 KB) body with Content-Length and no Connection: close: the
	// body is consumed and the pipelined request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body, no Connection: close: same outcome as case 0.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small Content-Length body with Connection: close: no read-ahead
	// (presumably pointless, since the connection will not be reused) and no
	// next request (none is sent).
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with Connection: close: the body is still read
	// ahead (presumably to reach the chunked terminator and any trailers),
	// but no next request follows.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large (1 MB) Content-Length body: no read-ahead, and the connection is
	// not reused (presumably the body exceeds what the server will drain).
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body: some input is consumed, but the next request is not
	// served (presumably the read-ahead gives up before the end of the body).
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body with Connection: close: input is consumed, and no
	// next request is sent.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large Content-Length body with Connection: close: neither read-ahead nor
	// a next request.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2003
2004 func TestHandlerBodyClose(t *testing.T) {
2005 setParallel(t)
2006 if testing.Short() && testenv.Builder() == "" {
2007 t.Skip("skipping in -short mode")
2008 }
2009 for i, tt := range handlerBodyCloseTests {
2010 testHandlerBodyClose(t, i, tt)
2011 }
2012 }
2013
2014 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2015 conn := new(testConn)
2016 body := strings.Repeat("x", tt.bodySize)
2017 if tt.bodyChunked {
2018 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2019 "Host: test\r\n" +
2020 tt.connectionHeader() +
2021 "Transfer-Encoding: chunked\r\n" +
2022 "\r\n")
2023 cw := internal.NewChunkedWriter(&conn.readBuf)
2024 io.WriteString(cw, body)
2025 cw.Close()
2026 conn.readBuf.WriteString("\r\n")
2027 } else {
2028 conn.readBuf.Write([]byte(fmt.Sprintf(
2029 "POST / HTTP/1.1\r\n"+
2030 "Host: test\r\n"+
2031 tt.connectionHeader()+
2032 "Content-Length: %d\r\n"+
2033 "\r\n", len(body))))
2034 conn.readBuf.Write([]byte(body))
2035 }
2036 if !tt.reqConnClose {
2037 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2038 }
2039 conn.closec = make(chan bool, 1)
2040
2041 readBufLen := func() int {
2042 conn.readMu.Lock()
2043 defer conn.readMu.Unlock()
2044 return conn.readBuf.Len()
2045 }
2046
2047 ls := &oneConnListener{conn}
2048 var numReqs int
2049 var size0, size1 int
2050 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2051 numReqs++
2052 if numReqs == 1 {
2053 size0 = readBufLen()
2054 req.Body.Close()
2055 size1 = readBufLen()
2056 }
2057 }))
2058 <-conn.closec
2059 if numReqs < 1 || numReqs > 2 {
2060 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2061 }
2062 didSearch := size0 != size1
2063 if didSearch != tt.wantEOFSearch {
2064 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2065 }
2066 if tt.wantNextReq && numReqs != 2 {
2067 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2068 }
2069 }
2070
2071
2072
// testHandlerBodyConsumer names one way a handler treats a request body before
// returning.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the body treatments exercised by the tests
// that iterate over it: ignore the body, close it, or drain it.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(body io.ReadCloser) { body.Close() }},
	{name: "discard", f: func(body io.ReadCloser) { io.Copy(io.Discard, body) }},
}
2083
2084 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2085 setParallel(t)
2086 defer afterTest(t)
2087 for _, handler := range testHandlerBodyConsumers {
2088 conn := new(testConn)
2089 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2090 "Host: test\r\n" +
2091 "Transfer-Encoding: chunked\r\n" +
2092 "\r\n" +
2093 "hax\r\n" +
2094 "GET /secret HTTP/1.1\r\n" +
2095 "Host: test\r\n" +
2096 "\r\n")
2097
2098 conn.closec = make(chan bool, 1)
2099 ls := &oneConnListener{conn}
2100 var numReqs int
2101 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2102 numReqs++
2103 if strings.Contains(req.URL.Path, "secret") {
2104 t.Error("Request for /secret encountered, should not have happened.")
2105 }
2106 handler.f(req.Body)
2107 }))
2108 <-conn.closec
2109 if numReqs != 1 {
2110 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2111 }
2112 }
2113 }
2114
2115 func TestInvalidTrailerClosesConnection(t *testing.T) {
2116 setParallel(t)
2117 defer afterTest(t)
2118 for _, handler := range testHandlerBodyConsumers {
2119 conn := new(testConn)
2120 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2121 "Host: test\r\n" +
2122 "Trailer: hack\r\n" +
2123 "Transfer-Encoding: chunked\r\n" +
2124 "\r\n" +
2125 "3\r\n" +
2126 "hax\r\n" +
2127 "0\r\n" +
2128 "I'm not a valid trailer\r\n" +
2129 "GET /secret HTTP/1.1\r\n" +
2130 "Host: test\r\n" +
2131 "\r\n")
2132
2133 conn.closec = make(chan bool, 1)
2134 ln := &oneConnListener{conn}
2135 var numReqs int
2136 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2137 numReqs++
2138 if strings.Contains(req.URL.Path, "secret") {
2139 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2140 }
2141 handler.f(req.Body)
2142 }))
2143 <-conn.closec
2144 if numReqs != 1 {
2145 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2146 }
2147 }
2148 }
2149
2150
2151
2152
2153 type slowTestConn struct {
2154
2155 script []any
2156 closec chan bool
2157
2158 mu sync.Mutex
2159 rd, wd time.Time
2160 noopConn
2161 }
2162
2163 func (c *slowTestConn) SetDeadline(t time.Time) error {
2164 c.SetReadDeadline(t)
2165 c.SetWriteDeadline(t)
2166 return nil
2167 }
2168
2169 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2170 c.mu.Lock()
2171 defer c.mu.Unlock()
2172 c.rd = t
2173 return nil
2174 }
2175
2176 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2177 c.mu.Lock()
2178 defer c.mu.Unlock()
2179 c.wd = t
2180 return nil
2181 }
2182
// Read replays c.script. It returns syscall.ETIMEDOUT if the read deadline has
// already passed and io.EOF once the script is exhausted. Otherwise it acts on
// the first cue:
//   - string: copied into b. A partial copy leaves the unread suffix at the
//     head of the script.
//   - time.Duration: the cue is popped, Read sleeps that long, and then
//     continues with the next cue. If a read deadline would expire first, Read
//     sleeps only until the deadline, keeps the unslept remainder as the head
//     cue, and returns syscall.ETIMEDOUT.
//
// Any other cue type panics. c.mu is held for the entire call, including
// sleeps, so concurrent deadline setters block until Read returns.
func (c *slowTestConn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
restart:
	if !c.rd.IsZero() && time.Now().After(c.rd) {
		return 0, syscall.ETIMEDOUT
	}
	if len(c.script) == 0 {
		return 0, io.EOF
	}

	switch cue := c.script[0].(type) {
	case time.Duration:
		if !c.rd.IsZero() {
			// The pause would outlast the read deadline: sleep only until
			// the deadline and leave the rest of the pause queued.
			if remaining := time.Until(c.rd); remaining < cue {
				c.script[0] = cue - remaining
				time.Sleep(remaining)
				return 0, syscall.ETIMEDOUT
			}
		}
		c.script = c.script[1:]
		time.Sleep(cue)
		goto restart

	case string:
		n = copy(b, cue)
		// If b was too small, keep the rest of this cue for the next Read.
		if len(cue) > n {
			c.script[0] = cue[n:]
		} else {
			c.script = c.script[1:]
		}

	default:
		panic("unknown cue in slowTestConn script")
	}

	return
}
2224
2225 func (c *slowTestConn) Close() error {
2226 select {
2227 case c.closec <- true:
2228 default:
2229 }
2230 return nil
2231 }
2232
2233 func (c *slowTestConn) Write(b []byte) (int, error) {
2234 if !c.wd.IsZero() && time.Now().After(c.wd) {
2235 return 0, syscall.ETIMEDOUT
2236 }
2237 return len(b), nil
2238 }
2239
2240 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2241 if testing.Short() {
2242 t.Skip("skipping in -short mode")
2243 }
2244 defer afterTest(t)
2245 for _, handler := range testHandlerBodyConsumers {
2246 conn := &slowTestConn{
2247 script: []any{
2248 "POST /public HTTP/1.1\r\n" +
2249 "Host: test\r\n" +
2250 "Content-Length: 10000\r\n" +
2251 "\r\n",
2252 "foo bar baz",
2253 600 * time.Millisecond,
2254 "GET /secret HTTP/1.1\r\n" +
2255 "Host: test\r\n" +
2256 "\r\n",
2257 },
2258 closec: make(chan bool, 1),
2259 }
2260 ls := &oneConnListener{conn}
2261
2262 var numReqs int
2263 s := Server{
2264 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2265 numReqs++
2266 if strings.Contains(req.URL.Path, "secret") {
2267 t.Error("Request for /secret encountered, should not have happened.")
2268 }
2269 handler.f(req.Body)
2270 }),
2271 ReadTimeout: 400 * time.Millisecond,
2272 }
2273 go s.Serve(ls)
2274 <-conn.closec
2275
2276 if numReqs != 1 {
2277 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2278 }
2279 }
2280 }
2281
2282
// cancelableTimeoutContext wraps a context. Once the wrapped context reports
// any error (for example after cancellation), Err reports
// context.DeadlineExceeded instead. This lets a test make a timeout happen on
// demand by canceling the wrapped context.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the wrapped context has no error, and
// context.DeadlineExceeded afterwards.
func (c cancelableTimeoutContext) Err() error {
	if err := c.Context.Err(); err == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2293
2294 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2295 func testTimeoutHandler(t *testing.T, mode testMode) {
2296 sendHi := make(chan bool, 1)
2297 writeErrors := make(chan error, 1)
2298 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2299 <-sendHi
2300 _, werr := w.Write([]byte("hi"))
2301 writeErrors <- werr
2302 })
2303 ctx, cancel := context.WithCancel(context.Background())
2304 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2305 cst := newClientServerTest(t, mode, h)
2306
2307
2308 sendHi <- true
2309 res, err := cst.c.Get(cst.ts.URL)
2310 if err != nil {
2311 t.Error(err)
2312 }
2313 if g, e := res.StatusCode, StatusOK; g != e {
2314 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2315 }
2316 body, _ := io.ReadAll(res.Body)
2317 if g, e := string(body), "hi"; g != e {
2318 t.Errorf("got body %q; expected %q", g, e)
2319 }
2320 if g := <-writeErrors; g != nil {
2321 t.Errorf("got unexpected Write error on first request: %v", g)
2322 }
2323
2324
2325 cancel()
2326
2327 res, err = cst.c.Get(cst.ts.URL)
2328 if err != nil {
2329 t.Error(err)
2330 }
2331 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2332 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2333 }
2334 body, _ = io.ReadAll(res.Body)
2335 if !strings.Contains(string(body), "<title>Timeout</title>") {
2336 t.Errorf("expected timeout body; got %q", string(body))
2337 }
2338 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2339 t.Errorf("response content-type = %q; want %q", g, w)
2340 }
2341
2342
2343
2344 sendHi <- true
2345 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2346 t.Errorf("expected Write error of %v; got %v", e, g)
2347 }
2348 }
2349
2350
2351 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2352 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2353 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2354 ms, _ := strconv.Atoi(r.URL.Path[1:])
2355 if ms == 0 {
2356 ms = 1
2357 }
2358 for i := 0; i < ms; i++ {
2359 w.Write([]byte("hi"))
2360 time.Sleep(time.Millisecond)
2361 }
2362 })
2363
2364 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2365
2366 c := ts.Client()
2367
2368 var wg sync.WaitGroup
2369 gate := make(chan bool, 10)
2370 n := 50
2371 if testing.Short() {
2372 n = 10
2373 gate = make(chan bool, 3)
2374 }
2375 for i := 0; i < n; i++ {
2376 gate <- true
2377 wg.Add(1)
2378 go func() {
2379 defer wg.Done()
2380 defer func() { <-gate }()
2381 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2382 if err == nil {
2383 io.Copy(io.Discard, res.Body)
2384 res.Body.Close()
2385 }
2386 }()
2387 }
2388 wg.Wait()
2389 }
2390
2391
2392
2393 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2394 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2395 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2396 w.WriteHeader(204)
2397 })
2398
2399 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2400
2401 var wg sync.WaitGroup
2402 gate := make(chan bool, 50)
2403 n := 500
2404 if testing.Short() {
2405 n = 10
2406 }
2407
2408 c := ts.Client()
2409 for i := 0; i < n; i++ {
2410 gate <- true
2411 wg.Add(1)
2412 go func() {
2413 defer wg.Done()
2414 defer func() { <-gate }()
2415 res, err := c.Get(ts.URL)
2416 if err != nil {
2417
2418
2419 t.Log(err)
2420 return
2421 }
2422 defer res.Body.Close()
2423 io.Copy(io.Discard, res.Body)
2424 }()
2425 }
2426 wg.Wait()
2427 }
2428
2429
2430 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2431 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2432 sendHi := make(chan bool, 1)
2433 writeErrors := make(chan error, 1)
2434 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2435 w.Header().Set("Content-Type", "text/plain")
2436 <-sendHi
2437 _, werr := w.Write([]byte("hi"))
2438 writeErrors <- werr
2439 })
2440 ctx, cancel := context.WithCancel(context.Background())
2441 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2442 cst := newClientServerTest(t, mode, h)
2443
2444
2445 sendHi <- true
2446 res, err := cst.c.Get(cst.ts.URL)
2447 if err != nil {
2448 t.Error(err)
2449 }
2450 if g, e := res.StatusCode, StatusOK; g != e {
2451 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2452 }
2453 body, _ := io.ReadAll(res.Body)
2454 if g, e := string(body), "hi"; g != e {
2455 t.Errorf("got body %q; expected %q", g, e)
2456 }
2457 if g := <-writeErrors; g != nil {
2458 t.Errorf("got unexpected Write error on first request: %v", g)
2459 }
2460
2461
2462 cancel()
2463
2464 res, err = cst.c.Get(cst.ts.URL)
2465 if err != nil {
2466 t.Error(err)
2467 }
2468 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2469 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2470 }
2471 body, _ = io.ReadAll(res.Body)
2472 if !strings.Contains(string(body), "<title>Timeout</title>") {
2473 t.Errorf("expected timeout body; got %q", string(body))
2474 }
2475
2476
2477
2478 sendHi <- true
2479 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2480 t.Errorf("expected Write error of %v; got %v", e, g)
2481 }
2482 }
2483
2484
2485 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2486 run(t, testTimeoutHandlerStartTimerWhenServing)
2487 }
2488 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2489 if testing.Short() {
2490 t.Skip("skipping sleeping test in -short mode")
2491 }
2492 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2493 w.WriteHeader(StatusNoContent)
2494 }
2495 timeout := 300 * time.Millisecond
2496 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2497 defer ts.Close()
2498
2499 c := ts.Client()
2500
2501
2502
2503
2504 time.Sleep(2 * timeout)
2505 res, err := c.Get(ts.URL)
2506 if err != nil {
2507 t.Fatal(err)
2508 }
2509 defer res.Body.Close()
2510 if res.StatusCode != StatusNoContent {
2511 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2512 }
2513 }
2514
2515 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2516 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2517 writeErrors := make(chan error, 1)
2518 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2519 w.Header().Set("Content-Type", "text/plain")
2520 var err error
2521
2522
2523
2524 for i := 0; i < 100; i++ {
2525 _, err = w.Write([]byte("a"))
2526 if err != nil {
2527 break
2528 }
2529 time.Sleep(1 * time.Millisecond)
2530 }
2531 writeErrors <- err
2532 })
2533 ctx, cancel := context.WithCancel(context.Background())
2534 cancel()
2535 h := NewTestTimeoutHandler(sayHi, ctx)
2536 cst := newClientServerTest(t, mode, h)
2537 defer cst.close()
2538
2539 res, err := cst.c.Get(cst.ts.URL)
2540 if err != nil {
2541 t.Error(err)
2542 }
2543 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2544 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2545 }
2546 body, _ := io.ReadAll(res.Body)
2547 if g, e := string(body), ""; g != e {
2548 t.Errorf("got body %q; expected %q", g, e)
2549 }
2550 if g, e := <-writeErrors, context.Canceled; g != e {
2551 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2552 }
2553 }
2554
2555
2556 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2557 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2558 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2559
2560 }
2561 timeout := 300 * time.Millisecond
2562 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2563
2564 c := ts.Client()
2565
2566 res, err := c.Get(ts.URL)
2567 if err != nil {
2568 t.Fatal(err)
2569 }
2570 defer res.Body.Close()
2571 if res.StatusCode != StatusOK {
2572 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2573 }
2574 }
2575
2576
2577 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2578 wrapper := func(h Handler) Handler {
2579 return TimeoutHandler(h, time.Second, "")
2580 }
2581 run(t, func(t *testing.T, mode testMode) {
2582 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2583 }, testNotParallel)
2584 }
2585
// TestRedirectBadPath calls Redirect on a request whose URL path lacks a
// leading slash (invalid input). It checks that the call still completes and
// records the requested 304 status.
func TestRedirectBadPath(t *testing.T) {
	req := &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2602
2603
// TestRedirect checks the Location header Redirect computes for a request to
// http://example.com/qux/. The cases cover absolute, scheme-relative,
// absolute-path and relative targets, path cleaning, query strings, and
// escaping of non-ASCII characters.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	type redirectCase struct {
		target   string // URL passed to Redirect
		location string // expected Location header
	}
	cases := []redirectCase{
		// Absolute URLs, whatever the scheme, are used unchanged.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// Scheme-relative and absolute-path targets are kept as given.
		{"//foobar.com/baz", "//foobar.com/baz"},
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative targets are resolved against the request's directory.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes collapse to a single one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// Query strings, even ones that contain URLs, are preserved.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII characters are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for i := range cases {
		c := cases[i]
		rec := httptest.NewRecorder()
		Redirect(rec, req, c.target, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", c.target, got, want)
		}
		if got := rec.Header().Get("Location"); got != c.location {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", c.target, got, c.location)
		}
	}
}
2648
2649
2650
// TestRedirectContentTypeAndBody checks the Content-Type and body produced
// by Redirect:
//   - GET gets an HTML "Found" link body.
//   - HEAD gets the HTML content type but no body.
//   - Other methods get neither.
//   - A Content-Type entry the caller already put in the header map (even an
//     empty or nil one) is left as is, and no body is written.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	for _, tc := range []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	} {
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		Redirect(rec, req, "/foo", 302)

		if rec.Code != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, rec.Code, 302)
		}
		if got := rec.Header().Get("Content-Type"); got != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, tc.wantBody)
		}
	}
}
2694
2695
2696
2697
2698
2699
2700
2701 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2702
2703 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2704 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2705 all, err := io.ReadAll(r.Body)
2706 if err != nil {
2707 t.Fatalf("handler ReadAll: %v", err)
2708 }
2709 if len(all) != 0 {
2710 t.Errorf("handler got %d bytes; expected 0", len(all))
2711 }
2712 rw.Header().Set("Content-Length", "0")
2713 }))
2714
2715 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2716 if err != nil {
2717 t.Fatal(err)
2718 }
2719 req.ContentLength = 0
2720
2721 var resp [5]*Response
2722 for i := range resp {
2723 resp[i], err = cst.c.Do(req)
2724 if err != nil {
2725 t.Fatalf("client post #%d: %v", i, err)
2726 }
2727 }
2728
2729 for i := range resp {
2730 all, err := io.ReadAll(resp[i].Body)
2731 if err != nil {
2732 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2733 }
2734 if len(all) != 0 {
2735 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2736 }
2737 }
2738 }
2739
2740 func TestHandlerPanicNil(t *testing.T) {
2741 run(t, func(t *testing.T, mode testMode) {
2742 testHandlerPanic(t, false, mode, nil, nil)
2743 }, testNotParallel)
2744 }
2745
2746 func TestHandlerPanic(t *testing.T) {
2747 run(t, func(t *testing.T, mode testMode) {
2748 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2749 }, testNotParallel)
2750 }
2751
2752 func TestHandlerPanicWithHijack(t *testing.T) {
2753
2754 run(t, func(t *testing.T, mode testMode) {
2755 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2756 }, []testMode{http1Mode})
2757 }
2758
2759 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
2760
2761
2762
2763
2764
2765
2766
2767
2768 pr, pw := io.Pipe()
2769 defer pw.Close()
2770
2771 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2772 if withHijack {
2773 rwc, _, err := w.(Hijacker).Hijack()
2774 if err != nil {
2775 t.Logf("unexpected error: %v", err)
2776 }
2777 defer rwc.Close()
2778 }
2779 panic(panicValue)
2780 })
2781 if wrapper != nil {
2782 handler = wrapper(handler)
2783 }
2784 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
2785 ts.Config.ErrorLog = log.New(pw, "", 0)
2786 })
2787
2788
2789 done := make(chan bool, 1)
2790 go func() {
2791 buf := make([]byte, 4<<10)
2792 _, err := pr.Read(buf)
2793 pr.Close()
2794 if err != nil && err != io.EOF {
2795 t.Error(err)
2796 }
2797 done <- true
2798 }()
2799
2800 _, err := cst.c.Get(cst.ts.URL)
2801 if err == nil {
2802 t.Logf("expected an error")
2803 }
2804
2805 if panicValue == nil {
2806 return
2807 }
2808
2809 <-done
2810 }
2811
// terrorWriter is an io.Writer that reports every write as a failure of the
// test t. It is used as a server ErrorLog destination when no log output is
// expected.
type terrorWriter struct{ t *testing.T }

// Write fails the test with p as the message and reports all of p as written.
func (tw terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	tw.t.Errorf("%s", p)
	return n, nil
}
2818
2819
2820
2821 func TestServerWriteHijackZeroBytes(t *testing.T) {
2822 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
2823 }
2824 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
2825 done := make(chan struct{})
2826 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2827 defer close(done)
2828 w.(Flusher).Flush()
2829 conn, _, err := w.(Hijacker).Hijack()
2830 if err != nil {
2831 t.Errorf("Hijack: %v", err)
2832 return
2833 }
2834 defer conn.Close()
2835 _, err = w.Write(nil)
2836 if err != ErrHijacked {
2837 t.Errorf("Write error = %v; want ErrHijacked", err)
2838 }
2839 }), func(ts *httptest.Server) {
2840 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
2841 }).ts
2842
2843 c := ts.Client()
2844 res, err := c.Get(ts.URL)
2845 if err != nil {
2846 t.Fatal(err)
2847 }
2848 res.Body.Close()
2849 <-done
2850 }
2851
2852 func TestServerNoDate(t *testing.T) {
2853 run(t, func(t *testing.T, mode testMode) {
2854 testServerNoHeader(t, mode, "Date")
2855 })
2856 }
2857
2858 func TestServerContentType(t *testing.T) {
2859 run(t, func(t *testing.T, mode testMode) {
2860 testServerNoHeader(t, mode, "Content-Type")
2861 })
2862 }
2863
2864 func testServerNoHeader(t *testing.T, mode testMode, header string) {
2865 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2866 w.Header()[header] = nil
2867 io.WriteString(w, "<html>foo</html>")
2868 }))
2869 res, err := cst.c.Get(cst.ts.URL)
2870 if err != nil {
2871 t.Fatal(err)
2872 }
2873 res.Body.Close()
2874 if got, ok := res.Header[header]; ok {
2875 t.Fatalf("Expected no %s header; got %q", header, got)
2876 }
2877 }
2878
2879 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
2880 func testStripPrefix(t *testing.T, mode testMode) {
2881 h := HandlerFunc(func(w ResponseWriter, r *Request) {
2882 w.Header().Set("X-Path", r.URL.Path)
2883 w.Header().Set("X-RawPath", r.URL.RawPath)
2884 })
2885 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
2886
2887 c := ts.Client()
2888
2889 cases := []struct {
2890 reqPath string
2891 path string
2892 rawPath string
2893 }{
2894 {"/foo/bar/qux", "/qux", ""},
2895 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
2896 {"/foo%2Fbar/qux", "", ""},
2897 {"/bar", "", ""},
2898 }
2899 for _, tc := range cases {
2900 t.Run(tc.reqPath, func(t *testing.T) {
2901 res, err := c.Get(ts.URL + tc.reqPath)
2902 if err != nil {
2903 t.Fatal(err)
2904 }
2905 res.Body.Close()
2906 if tc.path == "" {
2907 if res.StatusCode != StatusNotFound {
2908 t.Errorf("got %q, want 404 Not Found", res.Status)
2909 }
2910 return
2911 }
2912 if res.StatusCode != StatusOK {
2913 t.Fatalf("got %q, want 200 OK", res.Status)
2914 }
2915 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
2916 t.Errorf("got Path %q, want %q", g, w)
2917 }
2918 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
2919 t.Errorf("got RawPath %q, want %q", g, w)
2920 }
2921 })
2922 }
2923 }
2924
2925
// TestStripPrefixNotModifyRequest checks that StripPrefix leaves the URL.Path
// of the caller's *Request untouched.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	const original = "/foo/bar"
	req := httptest.NewRequest("GET", original, nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if req.URL.Path != original {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
2934
2935 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
2936 func testRequestLimit(t *testing.T, mode testMode) {
2937 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2938 t.Fatalf("didn't expect to get request in Handler")
2939 }), optQuietLog)
2940 req, _ := NewRequest("GET", cst.ts.URL, nil)
2941 var bytesPerHeader = len("header12345: val12345\r\n")
2942 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
2943 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
2944 }
2945 res, err := cst.c.Do(req)
2946 if res != nil {
2947 defer res.Body.Close()
2948 }
2949 if mode == http2Mode {
2950
2951
2952
2953
2954 if err == nil && res.StatusCode != 431 {
2955 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2956 }
2957 } else {
2958
2959
2960
2961
2962 if err != nil {
2963 t.Fatalf("Do: %v", err)
2964 }
2965 if res.StatusCode != 431 {
2966 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2967 }
2968 }
2969 }
2970
// neverEnding is an io.Reader that yields an endless stream of one byte
// value. Every Read fills all of p with byte(b) and never returns an error.
type neverEnding byte

func (b neverEnding) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Seed the first byte, then keep doubling the filled prefix with copy.
	p[0] = byte(b)
	for filled := 1; filled < len(p); filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return len(p), nil
}
2979
// bodyLimitReader is an endless request body made of 'a' bytes that counts
// how many bytes it has produced. Once that count exceeds limit, Read returns
// an "at limit" error; after Close, Read returns a "closed" error. Close also
// closes the closed channel, so callers can wait on it.
type bodyLimitReader struct {
	// mu guards count and serializes Read with Close.
	mu sync.Mutex
	// count is the total number of bytes returned by Read so far.
	count int
	// limit is the count above which Read starts failing.
	limit int
	// closed must be created by the caller; Close closes it.
	closed chan struct{}
}

// isClosed reports whether Close has already been called.
func (r *bodyLimitReader) isClosed() bool {
	select {
	case <-r.closed:
		return true
	default:
		return false
	}
}

// Read fills all of p with 'a' and adds len(p) to count. It fails instead if
// the reader is closed or count is already past limit, so a single Read can
// overshoot limit by up to len(p).
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	switch {
	case r.isClosed():
		return 0, errors.New("closed")
	case r.count > r.limit:
		return 0, errors.New("at limit")
	}
	for i := len(p) - 1; i >= 0; i-- {
		p[i] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close marks the reader as closed and wakes anyone waiting on r.closed.
// It must be called at most once; a second call panics on the double close.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3011
3012 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3013 func testRequestBodyLimit(t *testing.T, mode testMode) {
3014 const limit = 1 << 20
3015 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3016 r.Body = MaxBytesReader(w, r.Body, limit)
3017 n, err := io.Copy(io.Discard, r.Body)
3018 if err == nil {
3019 t.Errorf("expected error from io.Copy")
3020 }
3021 if n != limit {
3022 t.Errorf("io.Copy = %d, want %d", n, limit)
3023 }
3024 mbErr, ok := err.(*MaxBytesError)
3025 if !ok {
3026 t.Errorf("expected MaxBytesError, got %T", err)
3027 }
3028 if mbErr.Limit != limit {
3029 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3030 }
3031 }))
3032
3033 body := &bodyLimitReader{
3034 closed: make(chan struct{}),
3035 limit: limit * 200,
3036 }
3037 req, _ := NewRequest("POST", cst.ts.URL, body)
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048 resp, err := cst.c.Do(req)
3049 if err == nil {
3050 resp.Body.Close()
3051 }
3052
3053
3054 <-body.closed
3055
3056 if body.count > limit*100 {
3057 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3058 limit, body.count)
3059 }
3060 }
3061
3062
3063
3064 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3065 func testClientWriteShutdown(t *testing.T, mode testMode) {
3066 if runtime.GOOS == "plan9" {
3067 t.Skip("skipping test; see https://golang.org/issue/17906")
3068 }
3069 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3070 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3071 if err != nil {
3072 t.Fatalf("Dial: %v", err)
3073 }
3074 err = conn.(*net.TCPConn).CloseWrite()
3075 if err != nil {
3076 t.Fatalf("CloseWrite: %v", err)
3077 }
3078
3079 bs, err := io.ReadAll(conn)
3080 if err != nil {
3081 t.Errorf("ReadAll: %v", err)
3082 }
3083 got := string(bs)
3084 if got != "" {
3085 t.Errorf("read %q from server; want nothing", got)
3086 }
3087 }
3088
3089
3090
3091 func TestServerBufferedChunking(t *testing.T) {
3092 conn := new(testConn)
3093 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3094 conn.closec = make(chan bool, 1)
3095 ls := &oneConnListener{conn}
3096 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3097 rw.(Flusher).Flush()
3098 rw.Write([]byte{'x'})
3099 rw.Write([]byte{'y'})
3100 rw.Write([]byte{'z'})
3101 }))
3102 <-conn.closec
3103 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3104 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3105 conn.writeBuf.Bytes())
3106 }
3107 }
3108
3109
3110
3111
3112
3113 func TestServerGracefulClose(t *testing.T) {
3114
3115 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3116 }
3117 func testServerGracefulClose(t *testing.T, mode testMode) {
3118 runTimeSensitiveTest(t, []time.Duration{
3119 1 * time.Millisecond,
3120 5 * time.Millisecond,
3121 10 * time.Millisecond,
3122 50 * time.Millisecond,
3123 100 * time.Millisecond,
3124 500 * time.Millisecond,
3125 time.Second,
3126 5 * time.Second,
3127 }, func(t *testing.T, timeout time.Duration) error {
3128 SetRSTAvoidanceDelay(t, timeout)
3129 t.Logf("set RST avoidance delay to %v", timeout)
3130
3131 const bodySize = 5 << 20
3132 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3133 for i := 0; i < bodySize; i++ {
3134 req = append(req, 'x')
3135 }
3136
3137 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3138 Error(w, "bye", StatusUnauthorized)
3139 }))
3140
3141
3142 defer cst.close()
3143 ts := cst.ts
3144
3145 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3146 if err != nil {
3147 return err
3148 }
3149 writeErr := make(chan error)
3150 go func() {
3151 _, err := conn.Write(req)
3152 writeErr <- err
3153 }()
3154 defer func() {
3155 conn.Close()
3156
3157
3158
3159 <-writeErr
3160 }()
3161
3162 br := bufio.NewReader(conn)
3163 lineNum := 0
3164 for {
3165 line, err := br.ReadString('\n')
3166 if err == io.EOF {
3167 break
3168 }
3169 if err != nil {
3170 return fmt.Errorf("ReadLine: %v", err)
3171 }
3172 lineNum++
3173 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3174 t.Errorf("Response line = %q; want a 401", line)
3175 }
3176 }
3177 return nil
3178 })
3179 }
3180
3181 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3182 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3183 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3184 if r.Method != "get" {
3185 t.Errorf(`Got method %q; want "get"`, r.Method)
3186 }
3187 }))
3188 defer cst.close()
3189 req, _ := NewRequest("get", cst.ts.URL, nil)
3190 res, err := cst.c.Do(req)
3191 if err != nil {
3192 t.Error(err)
3193 return
3194 }
3195
3196 res.Body.Close()
3197 }
3198
3199
3200
3201
3202
3203 func TestContentLengthZero(t *testing.T) {
3204 run(t, testContentLengthZero, []testMode{http1Mode})
3205 }
3206 func testContentLengthZero(t *testing.T, mode testMode) {
3207 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3208
3209 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3210 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3211 if err != nil {
3212 t.Fatalf("error dialing: %v", err)
3213 }
3214 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3215 if err != nil {
3216 t.Fatalf("error writing: %v", err)
3217 }
3218 req, _ := NewRequest("GET", "/", nil)
3219 res, err := ReadResponse(bufio.NewReader(conn), req)
3220 if err != nil {
3221 t.Fatalf("error reading response: %v", err)
3222 }
3223 if te := res.TransferEncoding; len(te) > 0 {
3224 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3225 }
3226 if cl := res.ContentLength; cl != 0 {
3227 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3228 }
3229 conn.Close()
3230 }
3231 }
3232
3233 func TestCloseNotifier(t *testing.T) {
3234 run(t, testCloseNotifier, []testMode{http1Mode})
3235 }
3236 func testCloseNotifier(t *testing.T, mode testMode) {
3237 gotReq := make(chan bool, 1)
3238 sawClose := make(chan bool, 1)
3239 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3240 gotReq <- true
3241 cc := rw.(CloseNotifier).CloseNotify()
3242 <-cc
3243 sawClose <- true
3244 })).ts
3245 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3246 if err != nil {
3247 t.Fatalf("error dialing: %v", err)
3248 }
3249 diec := make(chan bool)
3250 go func() {
3251 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3252 if err != nil {
3253 t.Error(err)
3254 return
3255 }
3256 <-diec
3257 conn.Close()
3258 }()
3259 For:
3260 for {
3261 select {
3262 case <-gotReq:
3263 diec <- true
3264 case <-sawClose:
3265 break For
3266 }
3267 }
3268 ts.Close()
3269 }
3270
3271
3272
3273
3274
3275 func TestCloseNotifierPipelined(t *testing.T) {
3276 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3277 }
3278 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3279 gotReq := make(chan bool, 2)
3280 sawClose := make(chan bool, 2)
3281 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3282 gotReq <- true
3283 cc := rw.(CloseNotifier).CloseNotify()
3284 select {
3285 case <-cc:
3286 t.Error("unexpected CloseNotify")
3287 case <-time.After(100 * time.Millisecond):
3288 }
3289 sawClose <- true
3290 })).ts
3291 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3292 if err != nil {
3293 t.Fatalf("error dialing: %v", err)
3294 }
3295 diec := make(chan bool, 1)
3296 defer close(diec)
3297 go func() {
3298 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3299 _, err = io.WriteString(conn, req+req)
3300 if err != nil {
3301 t.Error(err)
3302 return
3303 }
3304 <-diec
3305 conn.Close()
3306 }()
3307 reqs := 0
3308 closes := 0
3309 for {
3310 select {
3311 case <-gotReq:
3312 reqs++
3313 if reqs > 2 {
3314 t.Fatal("too many requests")
3315 }
3316 case <-sawClose:
3317 closes++
3318 if closes > 1 {
3319 return
3320 }
3321 }
3322 }
3323 }
3324
3325 func TestCloseNotifierChanLeak(t *testing.T) {
3326 defer afterTest(t)
3327 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3328 for i := 0; i < 20; i++ {
3329 var output bytes.Buffer
3330 conn := &rwTestConn{
3331 Reader: bytes.NewReader(req),
3332 Writer: &output,
3333 closec: make(chan bool, 1),
3334 }
3335 ln := &oneConnListener{conn: conn}
3336 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3337
3338
3339
3340 _ = rw.(CloseNotifier).CloseNotify()
3341 })
3342 go Serve(ln, handler)
3343 <-conn.closec
3344 }
3345 }
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356 func TestHijackAfterCloseNotifier(t *testing.T) {
3357 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3358 }
3359 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3360 script := make(chan string, 2)
3361 script <- "closenotify"
3362 script <- "hijack"
3363 close(script)
3364 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3365 plan := <-script
3366 switch plan {
3367 default:
3368 panic("bogus plan; too many requests")
3369 case "closenotify":
3370 w.(CloseNotifier).CloseNotify()
3371 w.Header().Set("X-Addr", r.RemoteAddr)
3372 case "hijack":
3373 c, _, err := w.(Hijacker).Hijack()
3374 if err != nil {
3375 t.Errorf("Hijack in Handler: %v", err)
3376 return
3377 }
3378 if _, ok := c.(*net.TCPConn); !ok {
3379
3380
3381 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3382 }
3383 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3384 c.Close()
3385 return
3386 }
3387 })).ts
3388 res1, err := ts.Client().Get(ts.URL)
3389 if err != nil {
3390 log.Fatal(err)
3391 }
3392 res2, err := ts.Client().Get(ts.URL)
3393 if err != nil {
3394 log.Fatal(err)
3395 }
3396 addr1 := res1.Header.Get("X-Addr")
3397 addr2 := res2.Header.Get("X-Addr")
3398 if addr1 == "" || addr1 != addr2 {
3399 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3400 }
3401 }
3402
3403 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3404 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3405 }
3406 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3407 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3408 bodyOkay := make(chan bool, 1)
3409 gotCloseNotify := make(chan bool, 1)
3410 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3411 defer close(bodyOkay)
3412
3413 reqBody := r.Body
3414 r.Body = nil
3415
3416 gone := w.(CloseNotifier).CloseNotify()
3417 slurp, err := io.ReadAll(reqBody)
3418 if err != nil {
3419 t.Errorf("Body read: %v", err)
3420 return
3421 }
3422 if len(slurp) != len(requestBody) {
3423 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3424 return
3425 }
3426 if !bytes.Equal(slurp, requestBody) {
3427 t.Error("Backend read wrong request body.")
3428 return
3429 }
3430 bodyOkay <- true
3431 <-gone
3432 gotCloseNotify <- true
3433 })).ts
3434
3435 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3436 if err != nil {
3437 t.Fatal(err)
3438 }
3439 defer conn.Close()
3440
3441 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3442 len(requestBody), requestBody)
3443 if !<-bodyOkay {
3444
3445 return
3446 }
3447 conn.Close()
3448 <-gotCloseNotify
3449 }
3450
3451 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3452 func testOptions(t *testing.T, mode testMode) {
3453 uric := make(chan string, 2)
3454 mux := NewServeMux()
3455 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3456 uric <- r.RequestURI
3457 })
3458 ts := newClientServerTest(t, mode, mux).ts
3459
3460 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3461 if err != nil {
3462 t.Fatal(err)
3463 }
3464 defer conn.Close()
3465
3466
3467 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3468 if err != nil {
3469 t.Fatal(err)
3470 }
3471 br := bufio.NewReader(conn)
3472 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3473 if err != nil {
3474 t.Fatal(err)
3475 }
3476 if res.StatusCode != 200 {
3477 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3478 }
3479
3480
3481 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3482 if err != nil {
3483 t.Fatal(err)
3484 }
3485 res, err = ReadResponse(br, &Request{Method: "GET"})
3486 if err != nil {
3487 t.Fatal(err)
3488 }
3489 if res.StatusCode != 400 {
3490 t.Errorf("Got non-400 response to GET *: %#v", res)
3491 }
3492
3493 res, err = Get(ts.URL + "/second")
3494 if err != nil {
3495 t.Fatal(err)
3496 }
3497 res.Body.Close()
3498 if got := <-uric; got != "/second" {
3499 t.Errorf("Handler saw request for %q; want /second", got)
3500 }
3501 }
3502
3503 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3504 func testOptionsHandler(t *testing.T, mode testMode) {
3505 rc := make(chan *Request, 1)
3506
3507 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3508 rc <- r
3509 }), func(ts *httptest.Server) {
3510 ts.Config.DisableGeneralOptionsHandler = true
3511 }).ts
3512
3513 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3514 if err != nil {
3515 t.Fatal(err)
3516 }
3517 defer conn.Close()
3518
3519 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3520 if err != nil {
3521 t.Fatal(err)
3522 }
3523
3524 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3525 t.Errorf("Expected OPTIONS * request, got %v", got)
3526 }
3527 }
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538 func TestHeaderToWire(t *testing.T) {
3539 tests := []struct {
3540 name string
3541 handler func(ResponseWriter, *Request)
3542 check func(got, logs string) error
3543 }{
3544 {
3545 name: "write without Header",
3546 handler: func(rw ResponseWriter, r *Request) {
3547 rw.Write([]byte("hello world"))
3548 },
3549 check: func(got, logs string) error {
3550 if !strings.Contains(got, "Content-Length:") {
3551 return errors.New("no content-length")
3552 }
3553 if !strings.Contains(got, "Content-Type: text/plain") {
3554 return errors.New("no content-type")
3555 }
3556 return nil
3557 },
3558 },
3559 {
3560 name: "Header mutation before write",
3561 handler: func(rw ResponseWriter, r *Request) {
3562 h := rw.Header()
3563 h.Set("Content-Type", "some/type")
3564 rw.Write([]byte("hello world"))
3565 h.Set("Too-Late", "bogus")
3566 },
3567 check: func(got, logs string) error {
3568 if !strings.Contains(got, "Content-Length:") {
3569 return errors.New("no content-length")
3570 }
3571 if !strings.Contains(got, "Content-Type: some/type") {
3572 return errors.New("wrong content-type")
3573 }
3574 if strings.Contains(got, "Too-Late") {
3575 return errors.New("don't want too-late header")
3576 }
3577 return nil
3578 },
3579 },
3580 {
3581 name: "write then useless Header mutation",
3582 handler: func(rw ResponseWriter, r *Request) {
3583 rw.Write([]byte("hello world"))
3584 rw.Header().Set("Too-Late", "Write already wrote headers")
3585 },
3586 check: func(got, logs string) error {
3587 if strings.Contains(got, "Too-Late") {
3588 return errors.New("header appeared from after WriteHeader")
3589 }
3590 return nil
3591 },
3592 },
3593 {
3594 name: "flush then write",
3595 handler: func(rw ResponseWriter, r *Request) {
3596 rw.(Flusher).Flush()
3597 rw.Write([]byte("post-flush"))
3598 rw.Header().Set("Too-Late", "Write already wrote headers")
3599 },
3600 check: func(got, logs string) error {
3601 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3602 return errors.New("not chunked")
3603 }
3604 if strings.Contains(got, "Too-Late") {
3605 return errors.New("header appeared from after WriteHeader")
3606 }
3607 return nil
3608 },
3609 },
3610 {
3611 name: "header then flush",
3612 handler: func(rw ResponseWriter, r *Request) {
3613 rw.Header().Set("Content-Type", "some/type")
3614 rw.(Flusher).Flush()
3615 rw.Write([]byte("post-flush"))
3616 rw.Header().Set("Too-Late", "Write already wrote headers")
3617 },
3618 check: func(got, logs string) error {
3619 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3620 return errors.New("not chunked")
3621 }
3622 if strings.Contains(got, "Too-Late") {
3623 return errors.New("header appeared from after WriteHeader")
3624 }
3625 if !strings.Contains(got, "Content-Type: some/type") {
3626 return errors.New("wrong content-type")
3627 }
3628 return nil
3629 },
3630 },
3631 {
3632 name: "sniff-on-first-write content-type",
3633 handler: func(rw ResponseWriter, r *Request) {
3634 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3635 rw.Header().Set("Content-Type", "x/wrong")
3636 },
3637 check: func(got, logs string) error {
3638 if !strings.Contains(got, "Content-Type: text/html") {
3639 return errors.New("wrong content-type; want html")
3640 }
3641 return nil
3642 },
3643 },
3644 {
3645 name: "explicit content-type wins",
3646 handler: func(rw ResponseWriter, r *Request) {
3647 rw.Header().Set("Content-Type", "some/type")
3648 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3649 },
3650 check: func(got, logs string) error {
3651 if !strings.Contains(got, "Content-Type: some/type") {
3652 return errors.New("wrong content-type; want html")
3653 }
3654 return nil
3655 },
3656 },
3657 {
3658 name: "empty handler",
3659 handler: func(rw ResponseWriter, r *Request) {
3660 },
3661 check: func(got, logs string) error {
3662 if !strings.Contains(got, "Content-Length: 0") {
3663 return errors.New("want 0 content-length")
3664 }
3665 return nil
3666 },
3667 },
3668 {
3669 name: "only Header, no write",
3670 handler: func(rw ResponseWriter, r *Request) {
3671 rw.Header().Set("Some-Header", "some-value")
3672 },
3673 check: func(got, logs string) error {
3674 if !strings.Contains(got, "Some-Header") {
3675 return errors.New("didn't get header")
3676 }
3677 return nil
3678 },
3679 },
3680 {
3681 name: "WriteHeader call",
3682 handler: func(rw ResponseWriter, r *Request) {
3683 rw.WriteHeader(404)
3684 rw.Header().Set("Too-Late", "some-value")
3685 },
3686 check: func(got, logs string) error {
3687 if !strings.Contains(got, "404") {
3688 return errors.New("wrong status")
3689 }
3690 if strings.Contains(got, "Too-Late") {
3691 return errors.New("shouldn't have seen Too-Late")
3692 }
3693 return nil
3694 },
3695 },
3696 }
3697 for _, tc := range tests {
3698 ht := newHandlerTest(HandlerFunc(tc.handler))
3699 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3700 logs := ht.logbuf.String()
3701 if err := tc.check(got, logs); err != nil {
3702 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3703 }
3704 }
3705 }
3706
3707 type errorListener struct {
3708 errs []error
3709 }
3710
3711 func (l *errorListener) Accept() (c net.Conn, err error) {
3712 if len(l.errs) == 0 {
3713 return nil, io.EOF
3714 }
3715 err = l.errs[0]
3716 l.errs = l.errs[1:]
3717 return
3718 }
3719
3720 func (l *errorListener) Close() error {
3721 return nil
3722 }
3723
3724 func (l *errorListener) Addr() net.Addr {
3725 return dummyAddr("test-address")
3726 }
3727
3728 func TestAcceptMaxFds(t *testing.T) {
3729 setParallel(t)
3730
3731 ln := &errorListener{[]error{
3732 &net.OpError{
3733 Op: "accept",
3734 Err: syscall.EMFILE,
3735 }}}
3736 server := &Server{
3737 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3738 ErrorLog: log.New(io.Discard, "", 0),
3739 }
3740 err := server.Serve(ln)
3741 if err != io.EOF {
3742 t.Errorf("got error %v, want EOF", err)
3743 }
3744 }
3745
3746 func TestWriteAfterHijack(t *testing.T) {
3747 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3748 var buf strings.Builder
3749 wrotec := make(chan bool, 1)
3750 conn := &rwTestConn{
3751 Reader: bytes.NewReader(req),
3752 Writer: &buf,
3753 closec: make(chan bool, 1),
3754 }
3755 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3756 conn, bufrw, err := rw.(Hijacker).Hijack()
3757 if err != nil {
3758 t.Error(err)
3759 return
3760 }
3761 go func() {
3762 bufrw.Write([]byte("[hijack-to-bufw]"))
3763 bufrw.Flush()
3764 conn.Write([]byte("[hijack-to-conn]"))
3765 conn.Close()
3766 wrotec <- true
3767 }()
3768 })
3769 ln := &oneConnListener{conn: conn}
3770 go Serve(ln, handler)
3771 <-conn.closec
3772 <-wrotec
3773 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3774 t.Errorf("wrote %q; want %q", g, w)
3775 }
3776 }
3777
3778 func TestDoubleHijack(t *testing.T) {
3779 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3780 var buf bytes.Buffer
3781 conn := &rwTestConn{
3782 Reader: bytes.NewReader(req),
3783 Writer: &buf,
3784 closec: make(chan bool, 1),
3785 }
3786 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3787 conn, _, err := rw.(Hijacker).Hijack()
3788 if err != nil {
3789 t.Error(err)
3790 return
3791 }
3792 _, _, err = rw.(Hijacker).Hijack()
3793 if err == nil {
3794 t.Errorf("got err = nil; want err != nil")
3795 }
3796 conn.Close()
3797 })
3798 ln := &oneConnListener{conn: conn}
3799 go Serve(ln, handler)
3800 <-conn.closec
3801 }
3802
3803
3804
3805
3806
3807
3808
3809 func TestHTTP10ConnectionHeader(t *testing.T) {
3810 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
3811 }
3812 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
3813 mux := NewServeMux()
3814 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
3815 ts := newClientServerTest(t, mode, mux).ts
3816
3817
3818 tests := []struct {
3819 req string
3820 expect []string
3821 }{
3822 {
3823 req: "GET / HTTP/1.0\r\n\r\n",
3824 expect: nil,
3825 },
3826 {
3827 req: "OPTIONS * HTTP/1.0\r\n\r\n",
3828 expect: nil,
3829 },
3830 {
3831 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
3832 expect: []string{"keep-alive"},
3833 },
3834 }
3835
3836 for _, tt := range tests {
3837 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3838 if err != nil {
3839 t.Fatal("dial err:", err)
3840 }
3841
3842 _, err = fmt.Fprint(conn, tt.req)
3843 if err != nil {
3844 t.Fatal("conn write err:", err)
3845 }
3846
3847 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
3848 if err != nil {
3849 t.Fatal("ReadResponse err:", err)
3850 }
3851 conn.Close()
3852 resp.Body.Close()
3853
3854 got := resp.Header["Connection"]
3855 if !reflect.DeepEqual(got, tt.expect) {
3856 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
3857 }
3858 }
3859 }
3860
3861
3862 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
3863 func testServerReaderFromOrder(t *testing.T, mode testMode) {
3864 pr, pw := io.Pipe()
3865 const size = 3 << 20
3866 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3867 rw.Header().Set("Content-Type", "text/plain")
3868 done := make(chan bool)
3869 go func() {
3870 io.Copy(rw, pr)
3871 close(done)
3872 }()
3873 time.Sleep(25 * time.Millisecond)
3874 n, err := io.Copy(io.Discard, req.Body)
3875 if err != nil {
3876 t.Errorf("handler Copy: %v", err)
3877 return
3878 }
3879 if n != size {
3880 t.Errorf("handler Copy = %d; want %d", n, size)
3881 }
3882 pw.Write([]byte("hi"))
3883 pw.Close()
3884 <-done
3885 }))
3886
3887 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
3888 if err != nil {
3889 t.Fatal(err)
3890 }
3891 res, err := cst.c.Do(req)
3892 if err != nil {
3893 t.Fatal(err)
3894 }
3895 all, err := io.ReadAll(res.Body)
3896 if err != nil {
3897 t.Fatal(err)
3898 }
3899 res.Body.Close()
3900 if string(all) != "hi" {
3901 t.Errorf("Body = %q; want hi", all)
3902 }
3903 }
3904
3905
3906 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
3907 for _, code := range []int{StatusNotModified, StatusNoContent} {
3908 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3909 if r.URL.Path == "/header" {
3910 w.Header().Set("Content-Length", "123")
3911 }
3912 w.WriteHeader(code)
3913 if r.URL.Path == "/more" {
3914 w.Write([]byte("stuff"))
3915 }
3916 }))
3917 for _, req := range []string{
3918 "GET / HTTP/1.0",
3919 "GET /header HTTP/1.0",
3920 "GET /more HTTP/1.0",
3921 "GET / HTTP/1.1\nHost: foo",
3922 "GET /header HTTP/1.1\nHost: foo",
3923 "GET /more HTTP/1.1\nHost: foo",
3924 } {
3925 got := ht.rawResponse(req)
3926 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
3927 if !strings.Contains(got, wantStatus) {
3928 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
3929 } else if strings.Contains(got, "Content-Length") {
3930 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
3931 } else if strings.Contains(got, "stuff") {
3932 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
3933 }
3934 }
3935 }
3936 }
3937
3938 func TestContentTypeOkayOn204(t *testing.T) {
3939 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3940 w.Header().Set("Content-Length", "123")
3941 w.Header().Set("Content-Type", "foo/bar")
3942 w.WriteHeader(204)
3943 }))
3944 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
3945 if !strings.Contains(got, "Content-Type: foo/bar") {
3946 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
3947 }
3948 if strings.Contains(got, "Content-Length: 123") {
3949 t.Errorf("Response = %q; don't want a Content-Length", got)
3950 }
3951 }
3952
3953
3954
3955
3956
3957
3958
// TestTransportAndServerSharedBodyRace exercises a proxy handler that
// forwards its inbound request body (req.Body) unchanged as the body of an
// outbound client request. The same body is therefore used by both the
// Server and the client Transport. The proxy cancels the outbound request
// after reading half of the backend's response and then returns.
// NOTE(review): presumably this provokes concurrent use of the shared body
// (e.g. Close racing Read) for the race detector to catch.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// The scenario is timing dependent: each attempt sets a different RST
	// avoidance delay from the list below.
	// NOTE(review): presumably runTimeSensitiveTest moves on to the next
	// (longer) delay only when an attempt returns an error — confirm.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg counts running backend handlers; the deferred cleanup below
		// waits on it before closing the backend.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Register with wg before doing any work (including the
			// t.Logf below), so the deferred wg.Wait also waits for this
			// handler.
			wg.Add(1)
			defer wg.Done()

			// Echo up to bodySize bytes of the request body back, then
			// stay in the handler until the request's context is done
			// (e.g. when the proxy cancels the request).
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))

		// Wait for backend handlers before shutting the backend down.
		// Presumably this keeps them from calling t.Logf after the
		// attempt has finished.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Forward the inbound body as-is: from here on req.Body is
			// shared between this Server and the proxy's client.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the echoed body before canceling.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the in-flight outbound request. In http2Mode this
			// closes the Request.Cancel channel; otherwise it calls the
			// client's *Transport CancelRequest.
			// NOTE(review): presumably the split exists because the
			// http2-mode client's Transport is not a *Transport — confirm.
			// Returning right after lets the Server finish with req.Body
			// while the client side may still be tearing it down.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		// The original client request: bodySize bytes of 'a' sent to the
		// proxy. Only a transport-level error fails the attempt.
		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4048
4049
4050
4051
4052 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4053 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4054 }
4055 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4056 if testing.Short() {
4057 t.Skip("skipping in -short mode")
4058 }
4059
4060 readErrCh := make(chan error, 1)
4061 errCh := make(chan error, 2)
4062
4063 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4064 go func(body io.Reader) {
4065 _, err := body.Read(make([]byte, 100))
4066 readErrCh <- err
4067 }(req.Body)
4068 time.Sleep(500 * time.Millisecond)
4069 })).ts
4070
4071 closeConn := make(chan bool)
4072 defer close(closeConn)
4073 go func() {
4074 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4075 if err != nil {
4076 errCh <- err
4077 return
4078 }
4079 defer conn.Close()
4080 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4081 if err != nil {
4082 errCh <- err
4083 return
4084 }
4085
4086
4087 <-closeConn
4088 }()
4089 select {
4090 case err := <-readErrCh:
4091 if err == nil {
4092 t.Error("Read was nil. Expected error.")
4093 }
4094 case err := <-errCh:
4095 t.Error(err)
4096 }
4097 }
4098
4099
4100 func TestResponseWriterWriteString(t *testing.T) {
4101 okc := make(chan bool, 1)
4102 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4103 _, ok := w.(io.StringWriter)
4104 okc <- ok
4105 }))
4106 ht.rawResponse("GET / HTTP/1.0")
4107 select {
4108 case ok := <-okc:
4109 if !ok {
4110 t.Error("ResponseWriter did not implement io.StringWriter")
4111 }
4112 default:
4113 t.Error("handler was never called")
4114 }
4115 }
4116
4117 func TestAppendTime(t *testing.T) {
4118 var b [len(TimeFormat)]byte
4119 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4120 res := ExportAppendTime(b[:0], t1)
4121 t2, err := ParseTime(string(res))
4122 if err != nil {
4123 t.Fatalf("Error parsing time: %s", err)
4124 }
4125 if !t1.Equal(t2) {
4126 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4127 }
4128 }
4129
// TestServerConnState checks the sequence of ConnState values the Server
// reports through its ConnState hook for several connection lifecycles:
//   - keep-alive followed by a server- or client-requested close;
//   - hijacking, with and without a handler panic;
//   - a connection closed without any request;
//   - a malformed request;
//   - a raw client that closes after one response.
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
func testServerConnState(t *testing.T, mode testMode) {
	// Per-path behaviors, dispatched by the server's HandlerFunc below.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		// Ask for the connection to be closed after this response.
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		// Take over the connection and write a complete response by hand.
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		// Same as /hijack, but the handler panics after hijacking.
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the states reported for the one connection a
	// wantLog call is tracking (active). complete is closed once
	// len(want) states have arrived, or as soon as got diverges from want.
	type stateLog struct {
		active   net.Conn
		got      []ConnState
		want     []ConnState
		complete chan<- struct{}
	}
	// activeLog holds the current *stateLog. Its single buffer slot makes
	// it act as a lock: receive to take the log, send to hand it back.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh log expecting want, runs doRequests, waits
	// for the hook to signal completion, and compares the recorded states.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !reflect.DeepEqual(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// The log is deliberately not sent back: further hook calls block
		// until the next wantLog (or the deferred cleanup) installs a
		// fresh one.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			sl := <-activeLog
			// The first StateNew claims the log for its connection; states
			// reported for any other connection are errors.
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion once enough states have been seen, or
			// immediately when got is no longer a prefix of want.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	// Install an empty log so hook calls made during shutdown don't block.
	defer func() {
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET with optional header name/value pairs and reads
	// the whole body. Request and read errors are reported without
	// stopping the test.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive reuse; the second response asks for the connection to close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive reuse; the client's second request asks for the close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Hijacked connections end in StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// A panic after hijacking yields the same sequence.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed before sending anything never becomes active.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request: active, then closed. The one-byte Read
	// presumably waits for the server's reply (or EOF) before closing.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// A raw client reads one full response (the connection goes idle) and
	// then closes it.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4292
4293 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4294 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4295 }
4296 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4297 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4298 }), func(ts *httptest.Server) {
4299 ts.Config.SetKeepAlivesEnabled(false)
4300 }).ts
4301 res, err := ts.Client().Get(ts.URL)
4302 if err != nil {
4303 t.Fatal(err)
4304 }
4305 defer res.Body.Close()
4306 if !res.Close {
4307 t.Errorf("Body.Close == false; want true")
4308 }
4309 }
4310
4311
4312 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4313 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4314 var n int32
4315 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4316 atomic.AddInt32(&n, 1)
4317 }), optQuietLog)
4318 var wg sync.WaitGroup
4319 const reqs = 20
4320 for i := 0; i < reqs; i++ {
4321 wg.Add(1)
4322 go func() {
4323 defer wg.Done()
4324 res, err := cst.c.Get(cst.ts.URL)
4325 if err != nil {
4326
4327
4328 time.Sleep(10 * time.Millisecond)
4329 res, err = cst.c.Get(cst.ts.URL)
4330 if err != nil {
4331 t.Error(err)
4332 return
4333 }
4334 }
4335 defer res.Body.Close()
4336 _, err = io.Copy(io.Discard, res.Body)
4337 if err != nil {
4338 t.Error(err)
4339 return
4340 }
4341 }()
4342 }
4343 wg.Wait()
4344 if got := atomic.LoadInt32(&n); got != reqs {
4345 t.Errorf("handler ran %d times; want %d", got, reqs)
4346 }
4347 }
4348
4349 func TestServerConnStateNew(t *testing.T) {
4350 sawNew := false
4351 srv := &Server{
4352 ConnState: func(c net.Conn, state ConnState) {
4353 if state == StateNew {
4354 sawNew = true
4355 }
4356 },
4357 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4358 }
4359 srv.Serve(&oneConnListener{
4360 conn: &rwTestConn{
4361 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4362 Writer: io.Discard,
4363 },
4364 })
4365 if !sawNew {
4366 t.Error("StateNew not seen")
4367 }
4368 }
4369
4370 type closeWriteTestConn struct {
4371 rwTestConn
4372 didCloseWrite bool
4373 }
4374
4375 func (c *closeWriteTestConn) CloseWrite() error {
4376 c.didCloseWrite = true
4377 return nil
4378 }
4379
4380 func TestCloseWrite(t *testing.T) {
4381 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4382
4383 var srv Server
4384 var testConn closeWriteTestConn
4385 c := ExportServerNewConn(&srv, &testConn)
4386 ExportCloseWriteAndWait(c)
4387 if !testConn.didCloseWrite {
4388 t.Error("didn't see CloseWrite call")
4389 }
4390 }
4391
4392
4393
4394
4395
4396
4397
4398
4399 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4400 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4401 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4402 io.WriteString(w, "Hello, ")
4403 w.(Flusher).Flush()
4404 conn, buf, _ := w.(Hijacker).Hijack()
4405 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4406 if err := buf.Flush(); err != nil {
4407 t.Error(err)
4408 }
4409 if err := conn.Close(); err != nil {
4410 t.Error(err)
4411 }
4412 })).ts
4413 res, err := Get(ts.URL)
4414 if err != nil {
4415 t.Fatal(err)
4416 }
4417 defer res.Body.Close()
4418 all, err := io.ReadAll(res.Body)
4419 if err != nil {
4420 t.Fatal(err)
4421 }
4422 if want := "Hello, world!"; string(all) != want {
4423 t.Errorf("Got %q; want %q", all, want)
4424 }
4425 }
4426
4427
4428
4429
4430
4431
4432
4433 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4434 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4435 }
4436 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4437 if testing.Short() {
4438 t.Skip("skipping in -short mode")
4439 }
4440 const numReq = 3
4441 addrc := make(chan string, numReq)
4442 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4443 addrc <- r.RemoteAddr
4444 time.Sleep(500 * time.Millisecond)
4445 w.(Flusher).Flush()
4446 }), func(ts *httptest.Server) {
4447 ts.Config.WriteTimeout = 250 * time.Millisecond
4448 }).ts
4449
4450 errc := make(chan error, numReq)
4451 go func() {
4452 defer close(errc)
4453 for i := 0; i < numReq; i++ {
4454 res, err := Get(ts.URL)
4455 if res != nil {
4456 res.Body.Close()
4457 }
4458 errc <- err
4459 }
4460 }()
4461
4462 addrSeen := map[string]bool{}
4463 numOkay := 0
4464 for {
4465 select {
4466 case v := <-addrc:
4467 addrSeen[v] = true
4468 case err, ok := <-errc:
4469 if !ok {
4470 if len(addrSeen) != numReq {
4471 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4472 }
4473 if numOkay != 0 {
4474 t.Errorf("got %d successful client requests; want 0", numOkay)
4475 }
4476 return
4477 }
4478 if err == nil {
4479 numOkay++
4480 }
4481 }
4482 }
4483 }
4484
4485
4486
4487 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4488 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4489 }
4490 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4491 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4492 w.Header().Set("Transfer-Encoding", "foo")
4493 io.WriteString(w, "<html>")
4494 })).ts
4495 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4496 if err != nil {
4497 t.Fatalf("Dial: %v", err)
4498 }
4499 defer c.Close()
4500 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4501 t.Fatal(err)
4502 }
4503 bs := bufio.NewScanner(c)
4504 var got strings.Builder
4505 for bs.Scan() {
4506 if strings.TrimSpace(bs.Text()) == "" {
4507 break
4508 }
4509 got.WriteString(bs.Text())
4510 got.WriteByte('\n')
4511 }
4512 if err := bs.Err(); err != nil {
4513 t.Fatal(err)
4514 }
4515 if strings.Contains(got.String(), "Content-Length") {
4516 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4517 }
4518 if strings.Contains(got.String(), "Content-Type") {
4519 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4520 }
4521 }
4522
4523
4524
4525 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4526 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4527 "\r\n\r\n" +
4528 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4529 var buf bytes.Buffer
4530 conn := &rwTestConn{
4531 Reader: bytes.NewReader(req),
4532 Writer: &buf,
4533 closec: make(chan bool, 1),
4534 }
4535 ln := &oneConnListener{conn: conn}
4536 numReq := 0
4537 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4538 numReq++
4539 }))
4540 <-conn.closec
4541 if numReq != 2 {
4542 t.Errorf("num requests = %d; want 2", numReq)
4543 t.Logf("Res: %s", buf.Bytes())
4544 }
4545 }
4546
4547 func TestIssue13893_Expect100(t *testing.T) {
4548
4549 req := reqBytes(`PUT /readbody HTTP/1.1
4550 User-Agent: PycURL/7.22.0
4551 Host: 127.0.0.1:9000
4552 Accept: */*
4553 Expect: 100-continue
4554 Content-Length: 10
4555
4556 HelloWorld
4557
4558 `)
4559 var buf bytes.Buffer
4560 conn := &rwTestConn{
4561 Reader: bytes.NewReader(req),
4562 Writer: &buf,
4563 closec: make(chan bool, 1),
4564 }
4565 ln := &oneConnListener{conn: conn}
4566 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4567 if _, ok := r.Header["Expect"]; !ok {
4568 t.Error("Expect header should not be filtered out")
4569 }
4570 }))
4571 <-conn.closec
4572 }
4573
4574 func TestIssue11549_Expect100(t *testing.T) {
4575 req := reqBytes(`PUT /readbody HTTP/1.1
4576 User-Agent: PycURL/7.22.0
4577 Host: 127.0.0.1:9000
4578 Accept: */*
4579 Expect: 100-continue
4580 Content-Length: 10
4581
4582 HelloWorldPUT /noreadbody HTTP/1.1
4583 User-Agent: PycURL/7.22.0
4584 Host: 127.0.0.1:9000
4585 Accept: */*
4586 Expect: 100-continue
4587 Content-Length: 10
4588
4589 GET /should-be-ignored HTTP/1.1
4590 Host: foo
4591
4592 `)
4593 var buf strings.Builder
4594 conn := &rwTestConn{
4595 Reader: bytes.NewReader(req),
4596 Writer: &buf,
4597 closec: make(chan bool, 1),
4598 }
4599 ln := &oneConnListener{conn: conn}
4600 numReq := 0
4601 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4602 numReq++
4603 if r.URL.Path == "/readbody" {
4604 io.ReadAll(r.Body)
4605 }
4606 io.WriteString(w, "Hello world!")
4607 }))
4608 <-conn.closec
4609 if numReq != 2 {
4610 t.Errorf("num requests = %d; want 2", numReq)
4611 }
4612 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4613 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4614 }
4615 }
4616
4617
4618
4619 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4620 setParallel(t)
4621 conn := newTestConn()
4622 conn.readBuf.Write([]byte(fmt.Sprintf(
4623 "POST / HTTP/1.1\r\n" +
4624 "Host: test\r\n" +
4625 "Content-Length: 9999999999\r\n" +
4626 "\r\n" + strings.Repeat("a", 1<<20))))
4627
4628 ls := &oneConnListener{conn}
4629 var inHandlerLen int
4630 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4631 inHandlerLen = conn.readBuf.Len()
4632 rw.WriteHeader(404)
4633 }))
4634 <-conn.closec
4635 afterHandlerLen := conn.readBuf.Len()
4636
4637 if afterHandlerLen != inHandlerLen {
4638 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4639 }
4640 }
4641
4642 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4643 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4644 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4645 r.Body = nil
4646 fmt.Fprintf(w, "%v", r.RemoteAddr)
4647 }))
4648 get := func() string {
4649 res, err := cst.c.Get(cst.ts.URL)
4650 if err != nil {
4651 t.Fatal(err)
4652 }
4653 defer res.Body.Close()
4654 slurp, err := io.ReadAll(res.Body)
4655 if err != nil {
4656 t.Fatal(err)
4657 }
4658 return string(slurp)
4659 }
4660 a, b := get(), get()
4661 if a != b {
4662 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4663 }
4664 }
4665
4666
4667
4668 func TestServerValidatesHostHeader(t *testing.T) {
4669 tests := []struct {
4670 proto string
4671 host string
4672 want int
4673 }{
4674 {"HTTP/0.9", "", 505},
4675
4676 {"HTTP/1.1", "", 400},
4677 {"HTTP/1.1", "Host: \r\n", 200},
4678 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4679 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4680 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4681 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4682 {"HTTP/1.1", "Host: ::1\r\n", 200},
4683 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4684 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4685 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4686 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4687 {"HTTP/1.1", "Host: \x06\r\n", 400},
4688 {"HTTP/1.1", "Host: \xff\r\n", 400},
4689 {"HTTP/1.1", "Host: {\r\n", 400},
4690 {"HTTP/1.1", "Host: }\r\n", 400},
4691 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4692
4693
4694
4695 {"HTTP/1.0", "", 200},
4696 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4697 {"HTTP/1.0", "Host: \xff\r\n", 400},
4698
4699
4700 {"PRI * HTTP/2.0", "", 200},
4701
4702
4703 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4704
4705
4706 {"PRI / HTTP/2.0", "", 505},
4707 {"GET / HTTP/2.0", "", 505},
4708 {"GET / HTTP/3.0", "", 505},
4709 }
4710 for _, tt := range tests {
4711 conn := newTestConn()
4712 methodTarget := "GET / "
4713 if !strings.HasPrefix(tt.proto, "HTTP/") {
4714 methodTarget = ""
4715 }
4716 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4717
4718 ln := &oneConnListener{conn}
4719 srv := Server{
4720 ErrorLog: quietLog,
4721 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4722 }
4723 go srv.Serve(ln)
4724 <-conn.closec
4725 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4726 if err != nil {
4727 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4728 continue
4729 }
4730 if res.StatusCode != tt.want {
4731 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4732 }
4733 }
4734 }
4735
4736 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4737 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4738 }
4739 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
4740 const upgradeResponse = "upgrade here"
4741 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4742 conn, br, err := w.(Hijacker).Hijack()
4743 if err != nil {
4744 t.Error(err)
4745 return
4746 }
4747 defer conn.Close()
4748 if r.Method != "PRI" || r.RequestURI != "*" {
4749 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4750 return
4751 }
4752 if !r.Close {
4753 t.Errorf("Request.Close = true; want false")
4754 }
4755 const want = "SM\r\n\r\n"
4756 buf := make([]byte, len(want))
4757 n, err := io.ReadFull(br, buf)
4758 if err != nil || string(buf[:n]) != want {
4759 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4760 return
4761 }
4762 io.WriteString(conn, upgradeResponse)
4763 })).ts
4764
4765 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4766 if err != nil {
4767 t.Fatalf("Dial: %v", err)
4768 }
4769 defer c.Close()
4770 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4771 slurp, err := io.ReadAll(c)
4772 if err != nil {
4773 t.Fatal(err)
4774 }
4775 if string(slurp) != upgradeResponse {
4776 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4777 }
4778 }
4779
4780
4781
4782 func TestServerValidatesHeaders(t *testing.T) {
4783 setParallel(t)
4784 tests := []struct {
4785 header string
4786 want int
4787 }{
4788 {"", 200},
4789 {"Foo: bar\r\n", 200},
4790 {"X-Foo: bar\r\n", 200},
4791 {"Foo: a space\r\n", 200},
4792
4793 {"A space: foo\r\n", 400},
4794 {"foo\xffbar: foo\r\n", 400},
4795 {"foo\x00bar: foo\r\n", 400},
4796 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4797
4798
4799 {"Foo : bar\r\n", 400},
4800 {"Foo\t: bar\r\n", 400},
4801
4802 {"foo: foo foo\r\n", 200},
4803 {"foo: foo\tfoo\r\n", 200},
4804 {"foo: foo\x00foo\r\n", 400},
4805 {"foo: foo\x7ffoo\r\n", 400},
4806 {"foo: foo\xfffoo\r\n", 200},
4807 }
4808 for _, tt := range tests {
4809 conn := newTestConn()
4810 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
4811
4812 ln := &oneConnListener{conn}
4813 srv := Server{
4814 ErrorLog: quietLog,
4815 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4816 }
4817 go srv.Serve(ln)
4818 <-conn.closec
4819 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4820 if err != nil {
4821 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
4822 continue
4823 }
4824 if res.StatusCode != tt.want {
4825 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
4826 }
4827 }
4828 }
4829
4830 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
4831 run(t, testServerRequestContextCancel_ServeHTTPDone)
4832 }
4833 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
4834 ctxc := make(chan context.Context, 1)
4835 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4836 ctx := r.Context()
4837 select {
4838 case <-ctx.Done():
4839 t.Error("should not be Done in ServeHTTP")
4840 default:
4841 }
4842 ctxc <- ctx
4843 }))
4844 res, err := cst.c.Get(cst.ts.URL)
4845 if err != nil {
4846 t.Fatal(err)
4847 }
4848 res.Body.Close()
4849 ctx := <-ctxc
4850 select {
4851 case <-ctx.Done():
4852 default:
4853 t.Error("context should be done after ServeHTTP completes")
4854 }
4855 }
4856
4857
4858
4859
4860
4861 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
4862 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
4863 }
4864 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
4865 inHandler := make(chan struct{})
4866 handlerDone := make(chan struct{})
4867 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4868 close(inHandler)
4869 <-r.Context().Done()
4870 close(handlerDone)
4871 })).ts
4872 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4873 if err != nil {
4874 t.Fatal(err)
4875 }
4876 defer c.Close()
4877 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
4878 <-inHandler
4879 c.Close()
4880 <-handlerDone
4881 }
4882
4883 func TestServerContext_ServerContextKey(t *testing.T) {
4884 run(t, testServerContext_ServerContextKey)
4885 }
4886 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
4887 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4888 ctx := r.Context()
4889 got := ctx.Value(ServerContextKey)
4890 if _, ok := got.(*Server); !ok {
4891 t.Errorf("context value = %T; want *http.Server", got)
4892 }
4893 }))
4894 res, err := cst.c.Get(cst.ts.URL)
4895 if err != nil {
4896 t.Fatal(err)
4897 }
4898 res.Body.Close()
4899 }
4900
4901 func TestServerContext_LocalAddrContextKey(t *testing.T) {
4902 run(t, testServerContext_LocalAddrContextKey)
4903 }
4904 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
4905 ch := make(chan any, 1)
4906 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4907 ch <- r.Context().Value(LocalAddrContextKey)
4908 }))
4909 if _, err := cst.c.Head(cst.ts.URL); err != nil {
4910 t.Fatal(err)
4911 }
4912
4913 host := cst.ts.Listener.Addr().String()
4914 got := <-ch
4915 if addr, ok := got.(net.Addr); !ok {
4916 t.Errorf("local addr value = %T; want net.Addr", got)
4917 } else if fmt.Sprint(addr) != host {
4918 t.Errorf("local addr = %v; want %v", addr, host)
4919 }
4920 }
4921
4922
4923 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
4924 setParallel(t)
4925 defer afterTest(t)
4926 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4927 w.Header().Set("Transfer-Encoding", "chunked")
4928 w.Write([]byte("hello"))
4929 }))
4930 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4931 const hdr = "Transfer-Encoding: chunked"
4932 if n := strings.Count(resp, hdr); n != 1 {
4933 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
4934 }
4935 }
4936
4937
4938 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
4939 setParallel(t)
4940 defer afterTest(t)
4941 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4942 w.Header().Set("Transfer-Encoding", "gzip")
4943 gz := gzip.NewWriter(w)
4944 gz.Write([]byte("hello"))
4945 gz.Close()
4946 }))
4947 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4948 for _, v := range []string{"gzip", "chunked"} {
4949 hdr := "Transfer-Encoding: " + v
4950 if n := strings.Count(resp, hdr); n != 1 {
4951 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
4952 }
4953 }
4954 }
4955
4956 func BenchmarkClientServer(b *testing.B) {
4957 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
4958 }
4959 func benchmarkClientServer(b *testing.B, mode testMode) {
4960 b.ReportAllocs()
4961 b.StopTimer()
4962 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
4963 fmt.Fprintf(rw, "Hello world.\n")
4964 })).ts
4965 b.StartTimer()
4966
4967 c := ts.Client()
4968 for i := 0; i < b.N; i++ {
4969 res, err := c.Get(ts.URL)
4970 if err != nil {
4971 b.Fatal("Get:", err)
4972 }
4973 all, err := io.ReadAll(res.Body)
4974 res.Body.Close()
4975 if err != nil {
4976 b.Fatal("ReadAll:", err)
4977 }
4978 body := string(all)
4979 if body != "Hello world.\n" {
4980 b.Fatal("Got body:", body)
4981 }
4982 }
4983
4984 b.StopTimer()
4985 }
4986
4987 func BenchmarkClientServerParallel(b *testing.B) {
4988 for _, parallelism := range []int{4, 64} {
4989 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
4990 run(b, func(b *testing.B, mode testMode) {
4991 benchmarkClientServerParallel(b, parallelism, mode)
4992 }, []testMode{http1Mode, https1Mode, http2Mode})
4993 })
4994 }
4995 }
4996
4997 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
4998 b.ReportAllocs()
4999 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5000 fmt.Fprintf(rw, "Hello world.\n")
5001 })).ts
5002 b.ResetTimer()
5003 b.SetParallelism(parallelism)
5004 b.RunParallel(func(pb *testing.PB) {
5005 c := ts.Client()
5006 for pb.Next() {
5007 res, err := c.Get(ts.URL)
5008 if err != nil {
5009 b.Logf("Get: %v", err)
5010 continue
5011 }
5012 all, err := io.ReadAll(res.Body)
5013 res.Body.Close()
5014 if err != nil {
5015 b.Logf("ReadAll: %v", err)
5016 continue
5017 }
5018 body := string(all)
5019 if body != "Hello world.\n" {
5020 panic("Got body: " + body)
5021 }
5022 }
5023 })
5024 }
5025
5026
5027
5028
5029
5030
5031
5032
5033
5034
5035 func BenchmarkServer(b *testing.B) {
5036 b.ReportAllocs()
5037
5038 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5039 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5040 if err != nil {
5041 panic(err)
5042 }
5043 for i := 0; i < n; i++ {
5044 res, err := Get(url)
5045 if err != nil {
5046 log.Panicf("Get: %v", err)
5047 }
5048 all, err := io.ReadAll(res.Body)
5049 res.Body.Close()
5050 if err != nil {
5051 log.Panicf("ReadAll: %v", err)
5052 }
5053 body := string(all)
5054 if body != "Hello world.\n" {
5055 log.Panicf("Got body: %q", body)
5056 }
5057 }
5058 os.Exit(0)
5059 return
5060 }
5061
5062 var res = []byte("Hello world.\n")
5063 b.StopTimer()
5064 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5065 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5066 rw.Write(res)
5067 }))
5068 defer ts.Close()
5069 b.StartTimer()
5070
5071 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5072 cmd.Env = append([]string{
5073 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5074 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5075 }, os.Environ()...)
5076 out, err := cmd.CombinedOutput()
5077 if err != nil {
5078 b.Errorf("Test failure: %v, with output: %s", err, out)
5079 }
5080 }
5081
5082
// getNoBody sends a GET request for urlStr using the default client and
// closes the response body without reading it. On success it returns the
// response; on failure it returns a nil response and the error.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5091
5092
5093
// BenchmarkClient measures a client making sequential GET requests to a
// server that runs in a separate child process.
//
// The same function plays both roles. When TEST_BENCH_SERVER is set, it is
// the child: it listens on localhost at TEST_BENCH_SERVER_PORT (port 0, i.e.
// an ephemeral port, if unset), prints the listening address as its first
// line of stdout, and serves data until a request with a non-empty "stop"
// form value makes it exit. Otherwise it is the parent: it starts the child,
// reads the address, probes it once, then times b.N GET requests.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child process: run the server. This branch never returns
		// normally; it ends via os.Exit or log.Fatal.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// The parent reads this first stdout line to learn the address.
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			// A "stop" request exits the child without writing a response.
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		// A Server with no Handler uses DefaultServeMux, where the
		// handler above was registered.
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent process: re-run this test binary, restricted to this
	// benchmark, as the server child.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// done receives the child's exit status once and is then closed, so
	// both the final receive at the end and the deferred receive complete.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()
	// On any exit path, cancel ctx (presumably killing the child via the
	// CommandContext wiring) and wait for cmd.Wait to return.
	defer func() {
		cancel()
		<-done
	}()

	// Read the child's listening address from its first stdout line and
	// make one untimed probe request to confirm it is serving.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Timed region: b.N sequential GETs, each body checked against data.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. The error is ignored because the child's
	// handler exits before writing a response; the exit status is
	// checked instead.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5182
5183 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5184 b.ReportAllocs()
5185 req := reqBytes(`GET / HTTP/1.0
5186 Host: golang.org
5187 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5188 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5189 Accept-Encoding: gzip,deflate,sdch
5190 Accept-Language: en-US,en;q=0.8
5191 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5192 `)
5193 res := []byte("Hello world!\n")
5194
5195 conn := newTestConn()
5196 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5197 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5198 rw.Write(res)
5199 })
5200 ln := new(oneConnListener)
5201 for i := 0; i < b.N; i++ {
5202 conn.readBuf.Reset()
5203 conn.writeBuf.Reset()
5204 conn.readBuf.Write(req)
5205 ln.conn = conn
5206 Serve(ln, handler)
5207 <-conn.closec
5208 }
5209 }
5210
5211
// repeatReader is an io.Reader that yields content count times in a row
// and then reports io.EOF.
type repeatReader struct {
	content []byte // bytes to repeat
	count   int    // remaining full repetitions, including the current one
	off     int    // read offset within the current repetition
}

// Read copies as much of the current repetition as fits in p. Once a
// repetition has been fully consumed, it moves on to the next one. Read
// returns (0, io.EOF) after all repetitions are exhausted.
func (r *repeatReader) Read(p []byte) (n int, err error) {
	if r.count > 0 {
		n = copy(p, r.content[r.off:])
		r.off += n
		if r.off == len(r.content) {
			r.off = 0
			r.count--
		}
		return n, nil
	}
	return 0, io.EOF
}
5230
5231 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5232 b.ReportAllocs()
5233
5234 req := reqBytes(`GET / HTTP/1.1
5235 Host: golang.org
5236 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5237 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5238 Accept-Encoding: gzip,deflate,sdch
5239 Accept-Language: en-US,en;q=0.8
5240 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5241 `)
5242 res := []byte("Hello world!\n")
5243
5244 conn := &rwTestConn{
5245 Reader: &repeatReader{content: req, count: b.N},
5246 Writer: io.Discard,
5247 closec: make(chan bool, 1),
5248 }
5249 handled := 0
5250 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5251 handled++
5252 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5253 rw.Write(res)
5254 })
5255 ln := &oneConnListener{conn: conn}
5256 go Serve(ln, handler)
5257 <-conn.closec
5258 if b.N != handled {
5259 b.Errorf("b.N=%d but handled %d", b.N, handled)
5260 }
5261 }
5262
5263
5264
5265 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5266 b.ReportAllocs()
5267
5268 req := reqBytes(`GET / HTTP/1.1
5269 Host: golang.org
5270 `)
5271 res := []byte("Hello world!\n")
5272
5273 conn := &rwTestConn{
5274 Reader: &repeatReader{content: req, count: b.N},
5275 Writer: io.Discard,
5276 closec: make(chan bool, 1),
5277 }
5278 handled := 0
5279 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5280 handled++
5281 rw.Write(res)
5282 })
5283 ln := &oneConnListener{conn: conn}
5284 go Serve(ln, handler)
5285 <-conn.closec
5286 if b.N != handled {
5287 b.Errorf("b.N=%d but handled %d", b.N, handled)
5288 }
5289 }
5290
// someResponse is the fragment repeated to build response.
const someResponse = "<html>some response</html>"

// response is someResponse repeated as many whole times as fit in 2 KiB
// (2048 bytes). It is the body written by the server handler benchmarks.
var response = bytes.Repeat([]byte(someResponse), 2048/len(someResponse))
5295
5296
5297 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5298 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5299 w.Header().Set("Content-Type", "text/html")
5300 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5301 w.Write(response)
5302 }))
5303 }
5304
5305
5306 func BenchmarkServerHandlerNoLen(b *testing.B) {
5307 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5308 w.Header().Set("Content-Type", "text/html")
5309 w.Write(response)
5310 }))
5311 }
5312
5313
5314 func BenchmarkServerHandlerNoType(b *testing.B) {
5315 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5316 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5317 w.Write(response)
5318 }))
5319 }
5320
5321
5322 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5323 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5324 w.Write(response)
5325 }))
5326 }
5327
5328 func benchmarkHandler(b *testing.B, h Handler) {
5329 b.ReportAllocs()
5330 req := reqBytes(`GET / HTTP/1.1
5331 Host: golang.org
5332 `)
5333 conn := &rwTestConn{
5334 Reader: &repeatReader{content: req, count: b.N},
5335 Writer: io.Discard,
5336 closec: make(chan bool, 1),
5337 }
5338 handled := 0
5339 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5340 handled++
5341 h.ServeHTTP(rw, r)
5342 })
5343 ln := &oneConnListener{conn: conn}
5344 go Serve(ln, handler)
5345 <-conn.closec
5346 if b.N != handled {
5347 b.Errorf("b.N=%d but handled %d", b.N, handled)
5348 }
5349 }
5350
5351 func BenchmarkServerHijack(b *testing.B) {
5352 b.ReportAllocs()
5353 req := reqBytes(`GET / HTTP/1.1
5354 Host: golang.org
5355 `)
5356 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5357 conn, _, err := w.(Hijacker).Hijack()
5358 if err != nil {
5359 panic(err)
5360 }
5361 conn.Close()
5362 })
5363 conn := &rwTestConn{
5364 Writer: io.Discard,
5365 closec: make(chan bool, 1),
5366 }
5367 ln := &oneConnListener{conn: conn}
5368 for i := 0; i < b.N; i++ {
5369 conn.Reader = bytes.NewReader(req)
5370 ln.conn = conn
5371 Serve(ln, h)
5372 <-conn.closec
5373 }
5374 }
5375
5376 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5377 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5378 b.ReportAllocs()
5379 b.StopTimer()
5380 sawClose := make(chan bool)
5381 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5382 <-rw.(CloseNotifier).CloseNotify()
5383 sawClose <- true
5384 })).ts
5385 b.StartTimer()
5386 for i := 0; i < b.N; i++ {
5387 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5388 if err != nil {
5389 b.Fatalf("error dialing: %v", err)
5390 }
5391 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5392 if err != nil {
5393 b.Fatal(err)
5394 }
5395 conn.Close()
5396 <-sawClose
5397 }
5398 b.StopTimer()
5399 }
5400
5401
5402 func TestConcurrentServerServe(t *testing.T) {
5403 setParallel(t)
5404 for i := 0; i < 100; i++ {
5405 ln1 := &oneConnListener{conn: nil}
5406 ln2 := &oneConnListener{conn: nil}
5407 srv := Server{}
5408 go func() { srv.Serve(ln1) }()
5409 go func() { srv.Serve(ln2) }()
5410 }
5411 }
5412
// TestServerIdleTimeout checks, in HTTP/1 mode, that with IdleTimeout set to
// twice ReadHeaderTimeout:
//   - two quick requests share one connection,
//   - a request made after the connection has idled for 1.5x IdleTimeout uses
//     a different connection, and
//   - a connection that sends an incomplete request header and then waits
//     2x ReadHeaderTimeout sees its next read fail.
//
// runTimeSensitiveTest retries with the next, larger ReadHeaderTimeout when
// an attempt returns an error.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler echoes the client's remote address, so equal bodies
		// mean the requests arrived on the same connection.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// Once headers have been received, a body read failure
				// is treated as fatal rather than returned for a retry
				// (presumably because neither timeout should apply mid-response).
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a request line and Host header without the terminating
		// blank line, then wait past ReadHeaderTimeout; reading from the
		// connection is then expected to fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5488
// get issues a GET request for url using c and returns the response body as
// a string. It fails the test immediately if the request or the body read
// returns an error. As a test helper, it reports failures at the caller's
// line.
func get(t *testing.T, c *Client, url string) string {
	t.Helper()
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	return string(slurp)
}
5501
5502
5503
5504 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5505 run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
5506 }
5507 func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
5508 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5509 io.WriteString(w, r.RemoteAddr)
5510 })).ts
5511
5512 c := ts.Client()
5513 tr := c.Transport.(*Transport)
5514
5515 get := func() string { return get(t, c, ts.URL) }
5516
5517 a1, a2 := get(), get()
5518 if a1 == a2 {
5519 t.Logf("made two requests from a single conn %q (as expected)", a1)
5520 } else {
5521 t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
5522 }
5523
5524
5525
5526
5527
5528 if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
5529 t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
5530 }
5531
5532
5533 ts.Config.SetKeepAlivesEnabled(false)
5534
5535 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5536 if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
5537 if d > 0 {
5538 t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
5539 }
5540 return false
5541 }
5542 return true
5543 })
5544
5545
5546
5547
5548 }
5549
// TestServerShutdown checks Server.Shutdown with a request in flight. The
// first handler invocation records the server's connection states, starts
// Shutdown in the background, waits for the registered OnShutdown hook, and
// then issues further requests, which are expected to fail. The in-flight
// request must still complete, Shutdown must return nil, and exactly one
// connection must have been StateActive when the handler first ran.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
func testServerShutdown(t *testing.T, mode testMode) {
	// cst is assigned below; the handler refers to it only once a request
	// arrives, by which point it is set.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until Shutdown has run the OnShutdown hook registered
			// below.
			<-gotOnShutdown
			// New requests are expected to fail now. Keep trying while
			// they succeed: in HTTP/2 mode a success is only logged and
			// retried (go.dev/issue/59038); otherwise it is an error.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// This first request triggers Shutdown from inside its handler and
	// must still receive a response.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5610
5611 func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) }
5612 func testServerShutdownStateNew(t *testing.T, mode testMode) {
5613 if testing.Short() {
5614 t.Skip("test takes 5-6 seconds; skipping in short mode")
5615 }
5616
5617 var connAccepted sync.WaitGroup
5618 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5619
5620 }), func(ts *httptest.Server) {
5621 ts.Config.ConnState = func(conn net.Conn, state ConnState) {
5622 if state == StateNew {
5623 connAccepted.Done()
5624 }
5625 }
5626 }).ts
5627
5628
5629 connAccepted.Add(1)
5630 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5631 if err != nil {
5632 t.Fatal(err)
5633 }
5634 defer c.Close()
5635
5636
5637
5638
5639
5640 connAccepted.Wait()
5641
5642 shutdownRes := make(chan error, 1)
5643 go func() {
5644 shutdownRes <- ts.Config.Shutdown(context.Background())
5645 }()
5646 readRes := make(chan error, 1)
5647 go func() {
5648 _, err := c.Read([]byte{0})
5649 readRes <- err
5650 }()
5651
5652
5653
5654
5655 const expectTimeout = 5 * time.Second
5656
5657 t0 := time.Now()
5658 select {
5659 case got := <-shutdownRes:
5660 d := time.Since(t0)
5661 if got != nil {
5662 t.Fatalf("shutdown error after %v: %v", d, err)
5663 }
5664 if d < expectTimeout/2 {
5665 t.Errorf("shutdown too soon after %v", d)
5666 }
5667 case <-time.After(expectTimeout * 3 / 2):
5668 t.Fatalf("timeout waiting for shutdown")
5669 }
5670
5671
5672
5673 if err := <-readRes; err == nil {
5674 t.Error("expected error from Read")
5675 }
5676 }
5677
5678
// TestServerCloseDeadlock checks that calling Close twice on a zero-value
// Server that was never started returns rather than blocking.
func TestServerCloseDeadlock(t *testing.T) {
	var srv Server
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
5684
5685
5686
5687 func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
5688 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
5689 if mode == http2Mode {
5690 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5691 defer restore()
5692 }
5693
5694 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
5695 defer cst.close()
5696 srv := cst.ts.Config
5697 srv.SetKeepAlivesEnabled(false)
5698 for try := 0; try < 2; try++ {
5699 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5700 if !srv.ExportAllConnsIdle() {
5701 if d > 0 {
5702 t.Logf("test server still has active conns after %v", d)
5703 }
5704 return false
5705 }
5706 return true
5707 })
5708 conns := 0
5709 var info httptrace.GotConnInfo
5710 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5711 GotConn: func(v httptrace.GotConnInfo) {
5712 conns++
5713 info = v
5714 },
5715 })
5716 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
5717 if err != nil {
5718 t.Fatal(err)
5719 }
5720 res, err := cst.c.Do(req)
5721 if err != nil {
5722 t.Fatal(err)
5723 }
5724 res.Body.Close()
5725 if conns != 1 {
5726 t.Fatalf("request %v: got %v conns, want 1", try, conns)
5727 }
5728 if info.Reused || info.WasIdle {
5729 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
5730 }
5731 }
5732 }
5733
5734
5735
5736
5737 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
5738 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
5739 runTimeSensitiveTest(t, []time.Duration{
5740 10 * time.Millisecond,
5741 50 * time.Millisecond,
5742 250 * time.Millisecond,
5743 time.Second,
5744 2 * time.Second,
5745 }, func(t *testing.T, timeout time.Duration) error {
5746 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5747 select {
5748 case <-time.After(2 * timeout):
5749 fmt.Fprint(w, "ok")
5750 case <-r.Context().Done():
5751 fmt.Fprint(w, r.Context().Err())
5752 }
5753 }), func(ts *httptest.Server) {
5754 ts.Config.ReadTimeout = timeout
5755 })
5756 defer cst.close()
5757 ts := cst.ts
5758
5759 c := ts.Client()
5760
5761 res, err := c.Get(ts.URL)
5762 if err != nil {
5763 return fmt.Errorf("Get: %v", err)
5764 }
5765 slurp, err := io.ReadAll(res.Body)
5766 res.Body.Close()
5767 if err != nil {
5768 return fmt.Errorf("Body ReadAll: %v", err)
5769 }
5770 if string(slurp) != "ok" {
5771 return fmt.Errorf("got: %q, want ok", slurp)
5772 }
5773 return nil
5774 })
5775 }
5776
5777
5778
5779
5780 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
5781 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
5782 }
5783 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
5784 runTimeSensitiveTest(t, []time.Duration{
5785 10 * time.Millisecond,
5786 50 * time.Millisecond,
5787 250 * time.Millisecond,
5788 time.Second,
5789 2 * time.Second,
5790 }, func(t *testing.T, timeout time.Duration) error {
5791 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
5792 ts.Config.ReadHeaderTimeout = timeout
5793 ts.Config.IdleTimeout = 0
5794 })
5795 defer cst.close()
5796 ts := cst.ts
5797
5798
5799
5800 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5801 if err != nil {
5802 t.Fatalf("dial failed: %v", err)
5803 }
5804 br := bufio.NewReader(conn)
5805 defer conn.Close()
5806
5807 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5808 return fmt.Errorf("writing first request failed: %v", err)
5809 }
5810
5811 if _, err := ReadResponse(br, nil); err != nil {
5812 return fmt.Errorf("first response (before timeout) failed: %v", err)
5813 }
5814
5815
5816
5817 time.Sleep(timeout * 3 / 2)
5818
5819 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5820 return fmt.Errorf("writing second request failed: %v", err)
5821 }
5822
5823 if _, err := ReadResponse(br, nil); err != nil {
5824 return fmt.Errorf("second response (after timeout) failed: %v", err)
5825 }
5826
5827 return nil
5828 })
5829 }
5830
5831
5832
// runTimeSensitiveTest calls test with each duration in order and returns as
// soon as one call returns nil. When a call returns an error, it logs the
// error and moves on to the next duration. It fails the test fatally instead
// if that error came from the last duration or if t is already marked
// failed. With no durations, test is never called.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for i, d := range durations {
		err := test(t, d)
		switch {
		case err == nil:
			return
		case i == last || t.Failed():
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
5845
5846
5847
5848 func TestServerDuplicateBackgroundRead(t *testing.T) {
5849 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
5850 }
5851 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
5852 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
5853 testenv.SkipFlaky(t, 24826)
5854 }
5855
5856 goroutines := 5
5857 requests := 2000
5858 if testing.Short() {
5859 goroutines = 3
5860 requests = 100
5861 }
5862
5863 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
5864
5865 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
5866
5867 var wg sync.WaitGroup
5868 for i := 0; i < goroutines; i++ {
5869 wg.Add(1)
5870 go func() {
5871 defer wg.Done()
5872 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
5873 if err != nil {
5874 t.Error(err)
5875 return
5876 }
5877 defer cn.Close()
5878
5879 wg.Add(1)
5880 go func() {
5881 defer wg.Done()
5882 io.Copy(io.Discard, cn)
5883 }()
5884
5885 for j := 0; j < requests; j++ {
5886 if t.Failed() {
5887 return
5888 }
5889 _, err := cn.Write(reqBytes)
5890 if err != nil {
5891 t.Error(err)
5892 return
5893 }
5894 }
5895 }()
5896 }
5897 wg.Wait()
5898 }
5899
5900
5901
5902
5903
5904
5905 func TestServerHijackGetsBackgroundByte(t *testing.T) {
5906 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
5907 }
5908 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
5909 if runtime.GOOS == "plan9" {
5910 t.Skip("skipping test; see https://golang.org/issue/18657")
5911 }
5912 done := make(chan struct{})
5913 inHandler := make(chan bool, 1)
5914 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5915 defer close(done)
5916
5917
5918 inHandler <- true
5919
5920 conn, buf, err := w.(Hijacker).Hijack()
5921 if err != nil {
5922 t.Error(err)
5923 return
5924 }
5925 defer conn.Close()
5926
5927 peek, err := buf.Reader.Peek(3)
5928 if string(peek) != "foo" || err != nil {
5929 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
5930 }
5931
5932 select {
5933 case <-r.Context().Done():
5934 t.Error("context unexpectedly canceled")
5935 default:
5936 }
5937 })).ts
5938
5939 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
5940 if err != nil {
5941 t.Fatal(err)
5942 }
5943 defer cn.Close()
5944 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5945 t.Fatal(err)
5946 }
5947 <-inHandler
5948 if _, err := cn.Write([]byte("foo")); err != nil {
5949 t.Fatal(err)
5950 }
5951
5952 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
5953 t.Fatal(err)
5954 }
5955 <-done
5956 }
5957
5958
5959
5960
5961 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
5962 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
5963 }
5964 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
5965 if runtime.GOOS == "plan9" {
5966 t.Skip("skipping test; see https://golang.org/issue/18657")
5967 }
5968 done := make(chan struct{})
5969 const size = 8 << 10
5970 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5971 defer close(done)
5972
5973 conn, buf, err := w.(Hijacker).Hijack()
5974 if err != nil {
5975 t.Error(err)
5976 return
5977 }
5978 defer conn.Close()
5979 slurp, err := io.ReadAll(buf.Reader)
5980 if err != nil {
5981 t.Errorf("Copy: %v", err)
5982 }
5983 allX := true
5984 for _, v := range slurp {
5985 if v != 'x' {
5986 allX = false
5987 }
5988 }
5989 if len(slurp) != size {
5990 t.Errorf("read %d; want %d", len(slurp), size)
5991 } else if !allX {
5992 t.Errorf("read %q; want %d 'x'", slurp, size)
5993 }
5994 })).ts
5995
5996 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
5997 if err != nil {
5998 t.Fatal(err)
5999 }
6000 defer cn.Close()
6001 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6002 strings.Repeat("x", size)); err != nil {
6003 t.Fatal(err)
6004 }
6005 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6006 t.Fatal(err)
6007 }
6008
6009 <-done
6010 }
6011
6012
6013 func TestServerValidatesMethod(t *testing.T) {
6014 tests := []struct {
6015 method string
6016 want int
6017 }{
6018 {"GET", 200},
6019 {"GE(T", 400},
6020 }
6021 for _, tt := range tests {
6022 conn := newTestConn()
6023 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6024
6025 ln := &oneConnListener{conn}
6026 go Serve(ln, serve(200))
6027 <-conn.closec
6028 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6029 if err != nil {
6030 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6031 continue
6032 }
6033 if res.StatusCode != tt.want {
6034 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6035 }
6036 }
6037 }
6038
6039
// eofListenerNotComparable is a net.Listener whose underlying type is a
// slice, so its values cannot be compared with ==. Accept always fails with
// io.EOF, Addr reports nil, and Close does nothing.
type eofListenerNotComparable []int

// Accept never yields a connection; it always reports io.EOF.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

// Addr returns a nil address.
func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

// Close is a no-op that always succeeds.
func (eofListenerNotComparable) Close() error {
	return nil
}
6045
6046
6047 func TestServerListenNotComparableListener(t *testing.T) {
6048 var s Server
6049 s.Serve(make(eofListenerNotComparable, 1))
6050 }
6051
6052
// countCloseListener wraps a net.Listener and counts how many times Close is
// called. Only the first call is forwarded to the wrapped listener, and only
// when it is non-nil. Later calls are counted and return nil.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; accessed atomically
}

// Close records the call and closes the embedded listener on the first call only.
func (p *countCloseListener) Close() error {
	if atomic.AddInt32(&p.closes, 1) != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6065
6066
6067 func TestServerCloseListenerOnce(t *testing.T) {
6068 setParallel(t)
6069 defer afterTest(t)
6070
6071 ln := newLocalListener(t)
6072 defer ln.Close()
6073
6074 cl := &countCloseListener{Listener: ln}
6075 server := &Server{}
6076 sdone := make(chan bool, 1)
6077
6078 go func() {
6079 server.Serve(cl)
6080 sdone <- true
6081 }()
6082 time.Sleep(10 * time.Millisecond)
6083 server.Shutdown(context.Background())
6084 ln.Close()
6085 <-sdone
6086
6087 nclose := atomic.LoadInt32(&cl.closes)
6088 if nclose != 1 {
6089 t.Errorf("Close calls = %v; want 1", nclose)
6090 }
6091 }
6092
6093
6094 func TestServerShutdownThenServe(t *testing.T) {
6095 var srv Server
6096 cl := &countCloseListener{Listener: nil}
6097 srv.Shutdown(context.Background())
6098 got := srv.Serve(cl)
6099 if got != ErrServerClosed {
6100 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6101 }
6102 nclose := atomic.LoadInt32(&cl.closes)
6103 if nclose != 1 {
6104 t.Errorf("Close calls = %v; want 1", nclose)
6105 }
6106 }
6107
6108
// TestStripPortFromHost checks that a ServeMux matches a request for
// "example.com:9000" against the host-only pattern "example.com/" rather
// than a pattern that includes the port.
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
		fmt.Fprint(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
		fmt.Fprint(w, "uh-oh!")
	})

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6129
6130 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6131 func testServerContexts(t *testing.T, mode testMode) {
6132 type baseKey struct{}
6133 type connKey struct{}
6134 ch := make(chan context.Context, 1)
6135 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6136 ch <- r.Context()
6137 }), func(ts *httptest.Server) {
6138 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6139 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6140 t.Errorf("unexpected onceClose listener type %T", ln)
6141 }
6142 return context.WithValue(context.Background(), baseKey{}, "base")
6143 }
6144 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6145 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6146 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6147 }
6148 return context.WithValue(ctx, connKey{}, "conn")
6149 }
6150 }).ts
6151 res, err := ts.Client().Get(ts.URL)
6152 if err != nil {
6153 t.Fatal(err)
6154 }
6155 res.Body.Close()
6156 ctx := <-ch
6157 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6158 t.Errorf("base context key = %#v; want %q", got, want)
6159 }
6160 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6161 t.Errorf("conn context key = %#v; want %q", got, want)
6162 }
6163 }
6164
6165
6166 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6167 run(t, testConnContextNotModifyingAllContexts)
6168 }
6169 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6170 type connKey struct{}
6171 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6172 rw.Header().Set("Connection", "close")
6173 }), func(ts *httptest.Server) {
6174 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6175 if got := ctx.Value(connKey{}); got != nil {
6176 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6177 }
6178 return context.WithValue(ctx, connKey{}, "conn")
6179 }
6180 }).ts
6181
6182 var res *Response
6183 var err error
6184
6185 res, err = ts.Client().Get(ts.URL)
6186 if err != nil {
6187 t.Fatal(err)
6188 }
6189 res.Body.Close()
6190
6191 res, err = ts.Client().Get(ts.URL)
6192 if err != nil {
6193 t.Fatal(err)
6194 }
6195 res.Body.Close()
6196 }
6197
6198
6199
6200 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6201 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6202 }
6203 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6204 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6205 w.Write([]byte("Hello, World!"))
6206 })).ts
6207
6208 serverURL, err := url.Parse(cst.URL)
6209 if err != nil {
6210 t.Fatalf("Failed to parse server URL: %v", err)
6211 }
6212
6213 unsupportedTEs := []string{
6214 "fugazi",
6215 "foo-bar",
6216 "unknown",
6217 `" chunked"`,
6218 }
6219
6220 for _, badTE := range unsupportedTEs {
6221 http1ReqBody := fmt.Sprintf(""+
6222 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6223 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6224
6225 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6226 if err != nil {
6227 t.Errorf("%q. unexpected error: %v", badTE, err)
6228 continue
6229 }
6230
6231 wantBody := fmt.Sprintf("" +
6232 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6233 "Connection: close\r\n\r\nUnsupported transfer encoding")
6234
6235 if string(gotBody) != wantBody {
6236 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6237 }
6238 }
6239 }
6240
6241
6242 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6243 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6244 type setting struct {
6245 name string
6246 body []byte
6247
6248
6249
6250
6251 contentEncoding any
6252 wantContentType string
6253 }
6254
6255 settings := []*setting{
6256 {
6257 name: "gzip content-encoding, gzipped",
6258 contentEncoding: "application/gzip",
6259 wantContentType: "",
6260 body: func() []byte {
6261 buf := new(bytes.Buffer)
6262 gzw := gzip.NewWriter(buf)
6263 gzw.Write([]byte("doctype html><p>Hello</p>"))
6264 gzw.Close()
6265 return buf.Bytes()
6266 }(),
6267 },
6268 {
6269 name: "zlib content-encoding, zlibbed",
6270 contentEncoding: "application/zlib",
6271 wantContentType: "",
6272 body: func() []byte {
6273 buf := new(bytes.Buffer)
6274 zw := zlib.NewWriter(buf)
6275 zw.Write([]byte("doctype html><p>Hello</p>"))
6276 zw.Close()
6277 return buf.Bytes()
6278 }(),
6279 },
6280 {
6281 name: "no content-encoding",
6282 wantContentType: "application/x-gzip",
6283 body: func() []byte {
6284 buf := new(bytes.Buffer)
6285 gzw := gzip.NewWriter(buf)
6286 gzw.Write([]byte("doctype html><p>Hello</p>"))
6287 gzw.Close()
6288 return buf.Bytes()
6289 }(),
6290 },
6291 {
6292 name: "phony content-encoding",
6293 contentEncoding: "foo/bar",
6294 body: []byte("doctype html><p>Hello</p>"),
6295 },
6296 {
6297 name: "empty but set content-encoding",
6298 contentEncoding: "",
6299 wantContentType: "audio/mpeg",
6300 body: []byte("ID3"),
6301 },
6302 }
6303
6304 for _, tt := range settings {
6305 t.Run(tt.name, func(t *testing.T) {
6306 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6307 if tt.contentEncoding != nil {
6308 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6309 }
6310 rw.Write(tt.body)
6311 }))
6312
6313 res, err := cst.c.Get(cst.ts.URL)
6314 if err != nil {
6315 t.Fatalf("Failed to fetch URL: %v", err)
6316 }
6317 defer res.Body.Close()
6318
6319 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6320 if w != nil {
6321 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6322 } else if g != "" {
6323 t.Errorf("Unexpected Content-Encoding %q", g)
6324 }
6325 }
6326
6327 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6328 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6329 }
6330 })
6331 }
6332 }
6333
6334
6335
6336 func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
6337 run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
6338 }
// testTimeoutHandlerSuperfluousLogs wraps a handler that calls WriteHeader
// four times in TimeoutHandler. It checks that the server logs exactly three
// "superfluous response.WriteHeader call" entries, one per extra call, and
// that each names the correct line of this file. It covers two cases: the
// handler returns before the timeout, and the handler is still running after
// the timeout fires.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Record this file's base name and this function's name so the expected
	// log locations can be matched with a regexp below.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// The handler blocks on exitHandler until it receives a value or the
			// channel is closed when this subtest returns.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// The first WriteHeader call is legitimate and the next three are
			// superfluous. The log check below derives their line numbers from
			// the line of the runtime.Caller call. The three spurious calls must
			// therefore stay on the three lines immediately above it; do not
			// insert lines between them.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// Let the handler return right away unless this case needs it to
			// still be running when the timeout fires.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)
			// Keep the timeout short when it is supposed to fire.
			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Otherwise use a long timeout so it does not fire by accident.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop Date (time-dependent) and Content-Type before dumping the
			// response; wantResp lists neither header.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the full dumped response, body included.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect exactly one server log entry per superfluous WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastLine carries the line of the runtime.Caller call, which sits
			// directly after the last spurious WriteHeader. The first spurious
			// call is therefore three lines above it.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3
			// The i-th log entry must name the handler closure and the line of
			// the i-th spurious call.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6440
6441
6442
6443
// fetchWireResponse dials host over TCP and writes http1ReqBody to it
// unchanged. It returns every byte the peer sends back until the connection
// is closed or a read fails.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer c.Close()

	if _, writeErr := c.Write(http1ReqBody); writeErr != nil {
		return nil, writeErr
	}
	// Read the whole reply; the server is expected to close the connection.
	return io.ReadAll(c)
}
6456
6457 func BenchmarkResponseStatusLine(b *testing.B) {
6458 b.ReportAllocs()
6459 b.RunParallel(func(pb *testing.PB) {
6460 bw := bufio.NewWriter(io.Discard)
6461 var buf3 [3]byte
6462 for pb.Next() {
6463 Export_writeStatusLine(bw, true, 200, buf3[:])
6464 }
6465 })
6466 }
6467
6468 func TestDisableKeepAliveUpgrade(t *testing.T) {
6469 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6470 }
6471 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6472 if testing.Short() {
6473 t.Skip("skipping in short mode")
6474 }
6475
6476 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6477 w.Header().Set("Connection", "Upgrade")
6478 w.Header().Set("Upgrade", "someProto")
6479 w.WriteHeader(StatusSwitchingProtocols)
6480 c, buf, err := w.(Hijacker).Hijack()
6481 if err != nil {
6482 return
6483 }
6484 defer c.Close()
6485
6486
6487
6488 io.Copy(c, buf)
6489 }), func(ts *httptest.Server) {
6490 ts.Config.SetKeepAlivesEnabled(false)
6491 }).ts
6492
6493 cl := s.Client()
6494 cl.Transport.(*Transport).DisableKeepAlives = true
6495
6496 resp, err := cl.Get(s.URL)
6497 if err != nil {
6498 t.Fatalf("failed to perform request: %v", err)
6499 }
6500 defer resp.Body.Close()
6501
6502 if resp.StatusCode != StatusSwitchingProtocols {
6503 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6504 }
6505
6506 rwc, ok := resp.Body.(io.ReadWriteCloser)
6507 if !ok {
6508 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6509 }
6510
6511 _, err = rwc.Write([]byte("hello"))
6512 if err != nil {
6513 t.Fatalf("failed to write to body: %v", err)
6514 }
6515
6516 b := make([]byte, 5)
6517 _, err = io.ReadFull(rwc, b)
6518 if err != nil {
6519 t.Fatalf("failed to read from body: %v", err)
6520 }
6521
6522 if string(b) != "hello" {
6523 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6524 }
6525 }
6526
// tlogWriter is an io.Writer that forwards each write to t.Log. It lets a
// *log.Logger send its output to the test log.
type tlogWriter struct{ t *testing.T }

// Write logs p as a string and reports the whole slice as written.
func (w tlogWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.t.Log(string(p))
	return n, nil
}
6533
6534 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6535 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6536 }
6537 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6538 const wantBody = "want"
6539 const wantUpgrade = "someProto"
6540 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6541 w.Header().Set("Connection", "Upgrade")
6542 w.Header().Set("Upgrade", wantUpgrade)
6543 w.WriteHeader(StatusSwitchingProtocols)
6544 NewResponseController(w).Flush()
6545
6546
6547 w.WriteHeader(200)
6548 if _, err := w.Write([]byte("x")); err == nil {
6549 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6550 }
6551
6552 c, _, err := NewResponseController(w).Hijack()
6553 if err != nil {
6554 t.Errorf("Hijack: %v", err)
6555 return
6556 }
6557 defer c.Close()
6558 if _, err := c.Write([]byte(wantBody)); err != nil {
6559 t.Errorf("Write to hijacked body: %v", err)
6560 }
6561 }), func(ts *httptest.Server) {
6562
6563 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6564 }).ts
6565
6566 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6567 if err != nil {
6568 t.Fatalf("net.Dial: %v", err)
6569 }
6570 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6571 if err != nil {
6572 t.Fatalf("conn.Write: %v", err)
6573 }
6574 defer conn.Close()
6575
6576 r := bufio.NewReader(conn)
6577 res, err := ReadResponse(r, &Request{Method: "GET"})
6578 if err != nil {
6579 t.Fatal("ReadResponse error:", err)
6580 }
6581 if res.StatusCode != StatusSwitchingProtocols {
6582 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6583 }
6584 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6585 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6586 }
6587 body, err := io.ReadAll(r)
6588 if err != nil {
6589 t.Error(err)
6590 }
6591 if string(body) != wantBody {
6592 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6593 }
6594 }
6595
6596 func TestMuxRedirectRelative(t *testing.T) {
6597 setParallel(t)
6598 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6599 if err != nil {
6600 t.Errorf("%s", err)
6601 }
6602 mux := NewServeMux()
6603 resp := httptest.NewRecorder()
6604 mux.ServeHTTP(resp, req)
6605 if got, want := resp.Header().Get("Location"), "/"; got != want {
6606 t.Errorf("Location header expected %q; got %q", want, got)
6607 }
6608 if got, want := resp.Code, StatusMovedPermanently; got != want {
6609 t.Errorf("Expected response code %d; got %d", want, got)
6610 }
6611 }
6612
6613
6614 func TestQuerySemicolon(t *testing.T) {
6615 t.Cleanup(func() { afterTest(t) })
6616
6617 tests := []struct {
6618 query string
6619 xNoSemicolons string
6620 xWithSemicolons string
6621 expectParseFormErr bool
6622 }{
6623 {"?a=1;x=bad&x=good", "good", "bad", true},
6624 {"?a=1;b=bad&x=good", "good", "good", true},
6625 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6626 {"?a=1;x=good;x=bad", "", "good", true},
6627 }
6628
6629 run(t, func(t *testing.T, mode testMode) {
6630 for _, tt := range tests {
6631 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6632 allowSemicolons := false
6633 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6634 })
6635 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6636 allowSemicolons, expectParseFormErr := true, false
6637 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6638 })
6639 }
6640 })
6641 }
6642
6643 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6644 writeBackX := func(w ResponseWriter, r *Request) {
6645 x := r.URL.Query().Get("x")
6646 if expectParseFormErr {
6647 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6648 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6649 }
6650 } else {
6651 if err := r.ParseForm(); err != nil {
6652 t.Errorf("expected no error from ParseForm, got %v", err)
6653 }
6654 }
6655 if got := r.FormValue("x"); x != got {
6656 t.Errorf("got %q from FormValue, want %q", got, x)
6657 }
6658 fmt.Fprintf(w, "%s", x)
6659 }
6660
6661 h := Handler(HandlerFunc(writeBackX))
6662 if allowSemicolons {
6663 h = AllowQuerySemicolons(h)
6664 }
6665
6666 logBuf := &strings.Builder{}
6667 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
6668 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6669 }).ts
6670
6671 req, _ := NewRequest("GET", ts.URL+query, nil)
6672 res, err := ts.Client().Do(req)
6673 if err != nil {
6674 t.Fatal(err)
6675 }
6676 slurp, _ := io.ReadAll(res.Body)
6677 res.Body.Close()
6678 if got, want := res.StatusCode, 200; got != want {
6679 t.Errorf("Status = %d; want = %d", got, want)
6680 }
6681 if got, want := string(slurp), wantX; got != want {
6682 t.Errorf("Body = %q; want = %q", got, want)
6683 }
6684 }
6685
6686 func TestMaxBytesHandler(t *testing.T) {
6687
6688 defer afterTest(t)
6689
6690 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
6691 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
6692 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
6693 func(t *testing.T) {
6694 run(t, func(t *testing.T, mode testMode) {
6695 testMaxBytesHandler(t, mode, maxSize, requestSize)
6696 }, testNotParallel)
6697 })
6698 }
6699 }
6700 }
6701
6702 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
6703 runTimeSensitiveTest(t, []time.Duration{
6704 1 * time.Millisecond,
6705 5 * time.Millisecond,
6706 10 * time.Millisecond,
6707 50 * time.Millisecond,
6708 100 * time.Millisecond,
6709 500 * time.Millisecond,
6710 time.Second,
6711 5 * time.Second,
6712 }, func(t *testing.T, timeout time.Duration) error {
6713 SetRSTAvoidanceDelay(t, timeout)
6714 t.Logf("set RST avoidance delay to %v", timeout)
6715
6716 var (
6717 handlerN int64
6718 handlerErr error
6719 )
6720 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
6721 var buf bytes.Buffer
6722 handlerN, handlerErr = io.Copy(&buf, r.Body)
6723 io.Copy(w, &buf)
6724 })
6725
6726 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
6727
6728
6729 defer cst.close()
6730 ts := cst.ts
6731 c := ts.Client()
6732
6733 body := strings.Repeat("a", int(requestSize))
6734 var wg sync.WaitGroup
6735 defer wg.Wait()
6736 getBody := func() (io.ReadCloser, error) {
6737 wg.Add(1)
6738 body := &wgReadCloser{
6739 Reader: strings.NewReader(body),
6740 wg: &wg,
6741 }
6742 return body, nil
6743 }
6744 reqBody, _ := getBody()
6745 req, err := NewRequest("POST", ts.URL, reqBody)
6746 if err != nil {
6747 reqBody.Close()
6748 t.Fatal(err)
6749 }
6750 req.ContentLength = int64(len(body))
6751 req.GetBody = getBody
6752 req.Header.Set("Content-Type", "text/plain")
6753
6754 var buf strings.Builder
6755 res, err := c.Do(req)
6756 if err != nil {
6757 return fmt.Errorf("unexpected connection error: %v", err)
6758 } else {
6759 _, err = io.Copy(&buf, res.Body)
6760 res.Body.Close()
6761 if err != nil {
6762 return fmt.Errorf("unexpected read error: %v", err)
6763 }
6764 }
6765
6766
6767
6768
6769 if handlerN > maxSize {
6770 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
6771 }
6772 if requestSize > maxSize && handlerErr == nil {
6773 t.Error("expected error on handler side; got nil")
6774 }
6775 if requestSize <= maxSize {
6776 if handlerErr != nil {
6777 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
6778 }
6779 if handlerN != requestSize {
6780 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
6781 }
6782 }
6783 if buf.Len() != int(handlerN) {
6784 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
6785 }
6786
6787 return nil
6788 })
6789 }
6790
6791 func TestEarlyHints(t *testing.T) {
6792 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6793 h := w.Header()
6794 h.Add("Link", "</style.css>; rel=preload; as=style")
6795 h.Add("Link", "</script.js>; rel=preload; as=script")
6796 w.WriteHeader(StatusEarlyHints)
6797
6798 h.Add("Link", "</foo.js>; rel=preload; as=script")
6799 w.WriteHeader(StatusEarlyHints)
6800
6801 w.Write([]byte("stuff"))
6802 }))
6803
6804 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6805 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
6806 if !strings.Contains(got, expected) {
6807 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6808 }
6809 }
6810 func TestProcessing(t *testing.T) {
6811 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6812 w.WriteHeader(StatusProcessing)
6813 w.Write([]byte("stuff"))
6814 }))
6815
6816 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6817 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
6818 if !strings.Contains(got, expected) {
6819 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6820 }
6821 }
6822
6823 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
6824 func testParseFormCleanup(t *testing.T, mode testMode) {
6825 if mode == http2Mode {
6826 t.Skip("https://go.dev/issue/20253")
6827 }
6828
6829 const maxMemory = 1024
6830 const key = "file"
6831
6832 if runtime.GOOS == "windows" {
6833
6834 t.Skip("https://go.dev/issue/25965")
6835 }
6836
6837 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6838 r.ParseMultipartForm(maxMemory)
6839 f, _, err := r.FormFile(key)
6840 if err != nil {
6841 t.Errorf("r.FormFile(%q) = %v", key, err)
6842 return
6843 }
6844 of, ok := f.(*os.File)
6845 if !ok {
6846 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
6847 return
6848 }
6849 w.Write([]byte(of.Name()))
6850 }))
6851
6852 fBuf := new(bytes.Buffer)
6853 mw := multipart.NewWriter(fBuf)
6854 mf, err := mw.CreateFormFile(key, "myfile.txt")
6855 if err != nil {
6856 t.Fatal(err)
6857 }
6858 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
6859 t.Fatal(err)
6860 }
6861 if err := mw.Close(); err != nil {
6862 t.Fatal(err)
6863 }
6864 req, err := NewRequest("POST", cst.ts.URL, fBuf)
6865 if err != nil {
6866 t.Fatal(err)
6867 }
6868 req.Header.Set("Content-Type", mw.FormDataContentType())
6869 res, err := cst.c.Do(req)
6870 if err != nil {
6871 t.Fatal(err)
6872 }
6873 defer res.Body.Close()
6874 fname, err := io.ReadAll(res.Body)
6875 if err != nil {
6876 t.Fatal(err)
6877 }
6878 cst.close()
6879 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
6880 t.Errorf("file %q exists after HTTP handler returned", string(fname))
6881 }
6882 }
6883
6884 func TestHeadBody(t *testing.T) {
6885 const identityMode = false
6886 const chunkedMode = true
6887 run(t, func(t *testing.T, mode testMode) {
6888 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
6889 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
6890 })
6891 }
6892
6893 func TestGetBody(t *testing.T) {
6894 const identityMode = false
6895 const chunkedMode = true
6896 run(t, func(t *testing.T, mode testMode) {
6897 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
6898 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
6899 })
6900 }
6901
6902 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
6903 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6904 b, err := io.ReadAll(r.Body)
6905 if err != nil {
6906 t.Errorf("server reading body: %v", err)
6907 return
6908 }
6909 w.Header().Set("X-Request-Body", string(b))
6910 w.Header().Set("Content-Length", "0")
6911 }))
6912 defer cst.close()
6913 for _, reqBody := range []string{
6914 "",
6915 "",
6916 "request_body",
6917 "",
6918 } {
6919 var bodyReader io.Reader
6920 if reqBody != "" {
6921 bodyReader = strings.NewReader(reqBody)
6922 if chunked {
6923 bodyReader = bufio.NewReader(bodyReader)
6924 }
6925 }
6926 req, err := NewRequest(method, cst.ts.URL, bodyReader)
6927 if err != nil {
6928 t.Fatal(err)
6929 }
6930 res, err := cst.c.Do(req)
6931 if err != nil {
6932 t.Fatal(err)
6933 }
6934 res.Body.Close()
6935 if got, want := res.StatusCode, 200; got != want {
6936 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
6937 }
6938 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
6939 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
6940 }
6941 }
6942 }
6943
6944
6945
6946 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
6947 func testDisableContentLength(t *testing.T, mode testMode) {
6948 if mode == http2Mode {
6949 t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
6950 }
6951
6952 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6953 w.Header()["Content-Length"] = nil
6954 fmt.Fprintf(w, "OK")
6955 }))
6956
6957 res, err := noCL.c.Get(noCL.ts.URL)
6958 if err != nil {
6959 t.Fatal(err)
6960 }
6961 if got, haveCL := res.Header["Content-Length"]; haveCL {
6962 t.Errorf("Unexpected Content-Length: %q", got)
6963 }
6964 if err := res.Body.Close(); err != nil {
6965 t.Fatal(err)
6966 }
6967
6968 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6969 fmt.Fprintf(w, "OK")
6970 }))
6971
6972 res, err = withCL.c.Get(withCL.ts.URL)
6973 if err != nil {
6974 t.Fatal(err)
6975 }
6976 if got := res.Header.Get("Content-Length"); got != "2" {
6977 t.Errorf("Content-Length: %q; want 2", got)
6978 }
6979 if err := res.Body.Close(); err != nil {
6980 t.Fatal(err)
6981 }
6982 }
6983
View as plain text