Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "testing"
44 "testing/iotest"
45 "time"
46
47 "golang.org/x/net/http/httpguts"
48 )
49
50
51
52
53
54 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
55 if r.FormValue("close") == "true" {
56 w.Header().Set("Connection", "close")
57 }
58 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
59 w.Write([]byte(r.RemoteAddr))
60
61
62
63 if c, ok := ResponseWriterConnForTesting(w); ok {
64 fmt.Fprintf(w, ", %T %p", c, c)
65 }
66 })
67
68
69 type testCloseConn struct {
70 net.Conn
71 set *testConnSet
72 }
73
74 func (c *testCloseConn) Close() error {
75 c.set.remove(c)
76 return c.Conn.Close()
77 }
78
79
80
// testConnSet records every connection produced by a makeTestDial dialer,
// in dial order, along with whether each one has been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex         // guards closed and list
	closed map[net.Conn]bool  // conn -> has Close been called
	list   []net.Conn         // all inserted conns, in insertion order
}

// insert registers c as an open (not yet closed) connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.list = append(tcs.list, c)
	tcs.closed[c] = false
}

// remove marks c as closed. The connection stays in list so that a later
// check can still report it by position.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}
100
101
102 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
103 connSet := &testConnSet{
104 t: t,
105 closed: make(map[net.Conn]bool),
106 }
107 dial := func(n, addr string) (net.Conn, error) {
108 c, err := net.Dial(n, addr)
109 if err != nil {
110 return nil, err
111 }
112 tc := &testCloseConn{c, connSet}
113 connSet.insert(tc)
114 return tc, nil
115 }
116 return connSet, dial
117 }
118
119 func (tcs *testConnSet) check(t *testing.T) {
120 tcs.mu.Lock()
121 defer tcs.mu.Unlock()
122 for i := 4; i >= 0; i-- {
123 for i, c := range tcs.list {
124 if tcs.closed[c] {
125 continue
126 }
127 if i != 0 {
128
129
130 tcs.mu.Unlock()
131 time.Sleep(50 * time.Millisecond)
132 tcs.mu.Lock()
133 continue
134 }
135 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
136 }
137 }
138 }
139
140 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
141 func testReuseRequest(t *testing.T, mode testMode) {
142 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
143 w.Write([]byte("{}"))
144 })).ts
145
146 c := ts.Client()
147 req, _ := NewRequest("GET", ts.URL, nil)
148 res, err := c.Do(req)
149 if err != nil {
150 t.Fatal(err)
151 }
152 err = res.Body.Close()
153 if err != nil {
154 t.Fatal(err)
155 }
156
157 res, err = c.Do(req)
158 if err != nil {
159 t.Fatal(err)
160 }
161 err = res.Body.Close()
162 if err != nil {
163 t.Fatal(err)
164 }
165 }
166
167
168
169 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
170 func testTransportKeepAlives(t *testing.T, mode testMode) {
171 ts := newClientServerTest(t, mode, hostPortHandler).ts
172
173 c := ts.Client()
174 for _, disableKeepAlive := range []bool{false, true} {
175 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
176 fetch := func(n int) string {
177 res, err := c.Get(ts.URL)
178 if err != nil {
179 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
180 }
181 body, err := io.ReadAll(res.Body)
182 if err != nil {
183 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
184 }
185 return string(body)
186 }
187
188 body1 := fetch(1)
189 body2 := fetch(2)
190
191 bodiesDiffer := body1 != body2
192 if bodiesDiffer != disableKeepAlive {
193 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
194 disableKeepAlive, bodiesDiffer, body1, body2)
195 }
196 }
197 }
198
199 func TestTransportConnectionCloseOnResponse(t *testing.T) {
200 run(t, testTransportConnectionCloseOnResponse)
201 }
202 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
203 ts := newClientServerTest(t, mode, hostPortHandler).ts
204
205 connSet, testDial := makeTestDial(t)
206
207 c := ts.Client()
208 tr := c.Transport.(*Transport)
209 tr.Dial = testDial
210
211 for _, connectionClose := range []bool{false, true} {
212 fetch := func(n int) string {
213 req := new(Request)
214 var err error
215 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
216 if err != nil {
217 t.Fatalf("URL parse error: %v", err)
218 }
219 req.Method = "GET"
220 req.Proto = "HTTP/1.1"
221 req.ProtoMajor = 1
222 req.ProtoMinor = 1
223
224 res, err := c.Do(req)
225 if err != nil {
226 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
227 }
228 defer res.Body.Close()
229 body, err := io.ReadAll(res.Body)
230 if err != nil {
231 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
232 }
233 return string(body)
234 }
235
236 body1 := fetch(1)
237 body2 := fetch(2)
238 bodiesDiffer := body1 != body2
239 if bodiesDiffer != connectionClose {
240 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
241 connectionClose, bodiesDiffer, body1, body2)
242 }
243
244 tr.CloseIdleConnections()
245 }
246
247 connSet.check(t)
248 }
249
250
251
252
253
254
255
256 func TestTransportConnectionCloseOnRequest(t *testing.T) {
257 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
258 }
259 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
260 ts := newClientServerTest(t, mode, hostPortHandler).ts
261
262 connSet, testDial := makeTestDial(t)
263
264 c := ts.Client()
265 tr := c.Transport.(*Transport)
266 tr.Dial = testDial
267 for _, reqClose := range []bool{false, true} {
268 fetch := func(n int) string {
269 req := new(Request)
270 var err error
271 req.URL, err = url.Parse(ts.URL)
272 if err != nil {
273 t.Fatalf("URL parse error: %v", err)
274 }
275 req.Method = "GET"
276 req.Proto = "HTTP/1.1"
277 req.ProtoMajor = 1
278 req.ProtoMinor = 1
279 req.Close = reqClose
280
281 res, err := c.Do(req)
282 if err != nil {
283 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
284 }
285 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
286 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
287 reqClose, got, !reqClose)
288 }
289 body, err := io.ReadAll(res.Body)
290 if err != nil {
291 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
292 }
293 return string(body)
294 }
295
296 body1 := fetch(1)
297 body2 := fetch(2)
298
299 got := 1
300 if body1 != body2 {
301 got++
302 }
303 want := 1
304 if reqClose {
305 want = 2
306 }
307 if got != want {
308 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
309 reqClose, got, want, body1, body2)
310 }
311
312 tr.CloseIdleConnections()
313 }
314
315 connSet.check(t)
316 }
317
318
319
320
321 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
322 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
323 }
324 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
325 ts := newClientServerTest(t, mode, hostPortHandler).ts
326
327 c := ts.Client()
328 c.Transport.(*Transport).DisableKeepAlives = true
329
330 res, err := c.Get(ts.URL)
331 if err != nil {
332 t.Fatal(err)
333 }
334 res.Body.Close()
335 if res.Header.Get("X-Saw-Close") != "true" {
336 t.Errorf("handler didn't see Connection: close ")
337 }
338 }
339
340
341
342 func TestTransportRespectRequestWantsClose(t *testing.T) {
343 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
344 }
345 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
346 tests := []struct {
347 disableKeepAlives bool
348 close bool
349 }{
350 {disableKeepAlives: false, close: false},
351 {disableKeepAlives: false, close: true},
352 {disableKeepAlives: true, close: false},
353 {disableKeepAlives: true, close: true},
354 }
355
356 for _, tc := range tests {
357 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
358 func(t *testing.T) {
359 ts := newClientServerTest(t, mode, hostPortHandler).ts
360
361 c := ts.Client()
362 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
363 req, err := NewRequest("GET", ts.URL, nil)
364 if err != nil {
365 t.Fatal(err)
366 }
367 count := 0
368 trace := &httptrace.ClientTrace{
369 WroteHeaderField: func(key string, field []string) {
370 if key != "Connection" {
371 return
372 }
373 if httpguts.HeaderValuesContainsToken(field, "close") {
374 count += 1
375 }
376 },
377 }
378 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
379 req.Close = tc.close
380 res, err := c.Do(req)
381 if err != nil {
382 t.Fatal(err)
383 }
384 defer res.Body.Close()
385 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
386 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
387 }
388 })
389 }
390
391 }
392
393 func TestTransportIdleCacheKeys(t *testing.T) {
394 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
395 }
396 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
397 ts := newClientServerTest(t, mode, hostPortHandler).ts
398 c := ts.Client()
399 tr := c.Transport.(*Transport)
400
401 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
402 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
403 }
404
405 resp, err := c.Get(ts.URL)
406 if err != nil {
407 t.Error(err)
408 }
409 io.ReadAll(resp.Body)
410
411 keys := tr.IdleConnKeysForTesting()
412 if e, g := 1, len(keys); e != g {
413 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
414 }
415
416 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
417 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
418 }
419
420 tr.CloseIdleConnections()
421 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
422 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
423 }
424 }
425
426
427
428 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
429 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
430 const msg = "foobar"
431
432 var addrSeen map[string]int
433 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
434 addrSeen[r.RemoteAddr]++
435 if r.URL.Path == "/chunked/" {
436 w.WriteHeader(200)
437 w.(Flusher).Flush()
438 } else {
439 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
440 w.WriteHeader(200)
441 }
442 w.Write([]byte(msg))
443 })).ts
444
445 for pi, path := range []string{"/content-length/", "/chunked/"} {
446 wantLen := []int{len(msg), -1}[pi]
447 addrSeen = make(map[string]int)
448 for i := 0; i < 3; i++ {
449 res, err := ts.Client().Get(ts.URL + path)
450 if err != nil {
451 t.Errorf("Get %s: %v", path, err)
452 continue
453 }
454
455
456
457
458
459 defer res.Body.Close()
460
461 if res.ContentLength != int64(wantLen) {
462 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
463 }
464 got, err := io.ReadAll(res.Body)
465 if string(got) != msg || err != nil {
466 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
467 }
468 }
469 if len(addrSeen) != 1 {
470 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
471 }
472 }
473 }
474
475 func TestTransportMaxPerHostIdleConns(t *testing.T) {
476 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
477 }
478 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
479 stop := make(chan struct{})
480 defer close(stop)
481
482 resch := make(chan string)
483 gotReq := make(chan bool)
484 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
485 gotReq <- true
486 var msg string
487 select {
488 case <-stop:
489 return
490 case msg = <-resch:
491 }
492 _, err := w.Write([]byte(msg))
493 if err != nil {
494 t.Errorf("Write: %v", err)
495 return
496 }
497 })).ts
498
499 c := ts.Client()
500 tr := c.Transport.(*Transport)
501 maxIdleConnsPerHost := 2
502 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
503
504
505
506 donech := make(chan bool)
507 doReq := func() {
508 defer func() {
509 select {
510 case <-stop:
511 return
512 case donech <- t.Failed():
513 }
514 }()
515 resp, err := c.Get(ts.URL)
516 if err != nil {
517 t.Error(err)
518 return
519 }
520 if _, err := io.ReadAll(resp.Body); err != nil {
521 t.Errorf("ReadAll: %v", err)
522 return
523 }
524 }
525 go doReq()
526 <-gotReq
527 go doReq()
528 <-gotReq
529 go doReq()
530 <-gotReq
531
532 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
533 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
534 }
535
536 resch <- "res1"
537 <-donech
538 keys := tr.IdleConnKeysForTesting()
539 if e, g := 1, len(keys); e != g {
540 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
541 }
542 addr := ts.Listener.Addr().String()
543 cacheKey := "|http|" + addr
544 if keys[0] != cacheKey {
545 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
546 }
547 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
548 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
549 }
550
551 resch <- "res2"
552 <-donech
553 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
554 t.Errorf("after second response, idle conns = %d; want %d", g, w)
555 }
556
557 resch <- "res3"
558 <-donech
559 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
560 t.Errorf("after third response, idle conns = %d; want %d", g, w)
561 }
562 }
563
564 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
565 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
566 }
567 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
568 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
569 _, err := w.Write([]byte("foo"))
570 if err != nil {
571 t.Fatalf("Write: %v", err)
572 }
573 })).ts
574 c := ts.Client()
575 tr := c.Transport.(*Transport)
576 dialStarted := make(chan struct{})
577 stallDial := make(chan struct{})
578 tr.Dial = func(network, addr string) (net.Conn, error) {
579 dialStarted <- struct{}{}
580 <-stallDial
581 return net.Dial(network, addr)
582 }
583
584 tr.DisableKeepAlives = true
585 tr.MaxConnsPerHost = 1
586
587 preDial := make(chan struct{})
588 reqComplete := make(chan struct{})
589 doReq := func(reqId string) {
590 req, _ := NewRequest("GET", ts.URL, nil)
591 trace := &httptrace.ClientTrace{
592 GetConn: func(hostPort string) {
593 preDial <- struct{}{}
594 },
595 }
596 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
597 resp, err := tr.RoundTrip(req)
598 if err != nil {
599 t.Errorf("unexpected error for request %s: %v", reqId, err)
600 }
601 _, err = io.ReadAll(resp.Body)
602 if err != nil {
603 t.Errorf("unexpected error for request %s: %v", reqId, err)
604 }
605 reqComplete <- struct{}{}
606 }
607
608 go doReq("req1")
609 <-preDial
610 <-dialStarted
611
612
613 go doReq("req2")
614 <-preDial
615 select {
616 case <-dialStarted:
617 t.Error("req2 dial started while req1 dial in progress")
618 return
619 default:
620 }
621
622
623 stallDial <- struct{}{}
624 <-reqComplete
625
626
627 <-dialStarted
628 stallDial <- struct{}{}
629 <-reqComplete
630 }
631
632 func TestTransportMaxConnsPerHost(t *testing.T) {
633 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
634 }
635 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
636 CondSkipHTTP2(t)
637
638 h := HandlerFunc(func(w ResponseWriter, r *Request) {
639 _, err := w.Write([]byte("foo"))
640 if err != nil {
641 t.Fatalf("Write: %v", err)
642 }
643 })
644
645 ts := newClientServerTest(t, mode, h).ts
646 c := ts.Client()
647 tr := c.Transport.(*Transport)
648 tr.MaxConnsPerHost = 1
649
650 mu := sync.Mutex{}
651 var conns []net.Conn
652 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
653 tr.Dial = func(network, addr string) (net.Conn, error) {
654 atomic.AddInt32(&dialCnt, 1)
655 c, err := net.Dial(network, addr)
656 mu.Lock()
657 defer mu.Unlock()
658 conns = append(conns, c)
659 return c, err
660 }
661
662 doReq := func() {
663 trace := &httptrace.ClientTrace{
664 GotConn: func(connInfo httptrace.GotConnInfo) {
665 if !connInfo.Reused {
666 atomic.AddInt32(&gotConnCnt, 1)
667 }
668 },
669 TLSHandshakeStart: func() {
670 atomic.AddInt32(&tlsHandshakeCnt, 1)
671 },
672 }
673 req, _ := NewRequest("GET", ts.URL, nil)
674 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
675
676 resp, err := c.Do(req)
677 if err != nil {
678 t.Fatalf("request failed: %v", err)
679 }
680 defer resp.Body.Close()
681 _, err = io.ReadAll(resp.Body)
682 if err != nil {
683 t.Fatalf("read body failed: %v", err)
684 }
685 }
686
687 wg := sync.WaitGroup{}
688 for i := 0; i < 10; i++ {
689 wg.Add(1)
690 go func() {
691 defer wg.Done()
692 doReq()
693 }()
694 }
695 wg.Wait()
696
697 expected := int32(tr.MaxConnsPerHost)
698 if dialCnt != expected {
699 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
700 }
701 if gotConnCnt != expected {
702 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
703 }
704 if ts.TLS != nil && tlsHandshakeCnt != expected {
705 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
706 }
707
708 if t.Failed() {
709 t.FailNow()
710 }
711
712 mu.Lock()
713 for _, c := range conns {
714 c.Close()
715 }
716 conns = nil
717 mu.Unlock()
718 tr.CloseIdleConnections()
719
720 doReq()
721 expected++
722 if dialCnt != expected {
723 t.Errorf("round 2: too many dials: %d", dialCnt)
724 }
725 if gotConnCnt != expected {
726 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
727 }
728 if ts.TLS != nil && tlsHandshakeCnt != expected {
729 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
730 }
731 }
732
733 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
734 run(t, testTransportMaxConnsPerHostDialCancellation,
735 testNotParallel,
736 []testMode{http1Mode, https1Mode, http2Mode},
737 )
738 }
739
740 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
741 CondSkipHTTP2(t)
742
743 h := HandlerFunc(func(w ResponseWriter, r *Request) {
744 _, err := w.Write([]byte("foo"))
745 if err != nil {
746 t.Fatalf("Write: %v", err)
747 }
748 })
749
750 cst := newClientServerTest(t, mode, h)
751 defer cst.close()
752 ts := cst.ts
753 c := ts.Client()
754 tr := c.Transport.(*Transport)
755 tr.MaxConnsPerHost = 1
756
757
758 ctx, cancel := context.WithCancel(context.Background())
759 defer cancel()
760 SetPendingDialHooks(cancel, nil)
761 defer SetPendingDialHooks(nil, nil)
762
763 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
764 _, err := c.Do(req)
765 if !errors.Is(err, context.Canceled) {
766 t.Errorf("expected error %v, got %v", context.Canceled, err)
767 }
768
769
770 SetPendingDialHooks(nil, nil)
771 req, _ = NewRequest("GET", ts.URL, nil)
772 resp, err := c.Do(req)
773 if err != nil {
774 t.Fatalf("request failed: %v", err)
775 }
776 defer resp.Body.Close()
777 _, err = io.ReadAll(resp.Body)
778 if err != nil {
779 t.Fatalf("read body failed: %v", err)
780 }
781 }
782
783 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
784 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
785 }
786 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
788 io.WriteString(w, r.RemoteAddr)
789 })).ts
790
791 c := ts.Client()
792 tr := c.Transport.(*Transport)
793
794 doReq := func(name string) {
795
796
797 res, err := c.Post(ts.URL, "", nil)
798 if err != nil {
799 t.Fatalf("%s: %v", name, err)
800 }
801 if res.StatusCode != 200 {
802 t.Fatalf("%s: %v", name, res.Status)
803 }
804 defer res.Body.Close()
805 slurp, err := io.ReadAll(res.Body)
806 if err != nil {
807 t.Fatalf("%s: %v", name, err)
808 }
809 t.Logf("%s: ok (%q)", name, slurp)
810 }
811
812 doReq("first")
813 keys1 := tr.IdleConnKeysForTesting()
814
815 ts.CloseClientConnections()
816
817 var keys2 []string
818 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
819 keys2 = tr.IdleConnKeysForTesting()
820 if len(keys2) != 0 {
821 if d > 0 {
822 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
823 }
824 return false
825 }
826 return true
827 })
828
829 doReq("second")
830 }
831
832
833
834 func TestTransportServerClosingUnexpectedly(t *testing.T) {
835 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
836 }
837 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
838 ts := newClientServerTest(t, mode, hostPortHandler).ts
839 c := ts.Client()
840
841 fetch := func(n, retries int) string {
842 condFatalf := func(format string, arg ...any) {
843 if retries <= 0 {
844 t.Fatalf(format, arg...)
845 }
846 t.Logf("retrying shortly after expected error: "+format, arg...)
847 time.Sleep(time.Second / time.Duration(retries))
848 }
849 for retries >= 0 {
850 retries--
851 res, err := c.Get(ts.URL)
852 if err != nil {
853 condFatalf("error in req #%d, GET: %v", n, err)
854 continue
855 }
856 body, err := io.ReadAll(res.Body)
857 if err != nil {
858 condFatalf("error in req #%d, ReadAll: %v", n, err)
859 continue
860 }
861 res.Body.Close()
862 return string(body)
863 }
864 panic("unreachable")
865 }
866
867 body1 := fetch(1, 0)
868 body2 := fetch(2, 0)
869
870
871
872
873
874
875
876
877 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
878
879 body3 := fetch(3, 5)
880
881 if body1 != body2 {
882 t.Errorf("expected body1 and body2 to be equal")
883 }
884 if body2 == body3 {
885 t.Errorf("expected body2 and body3 to be different")
886 }
887 }
888
889
890
891 func TestStressSurpriseServerCloses(t *testing.T) {
892 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
893 }
894 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
895 if testing.Short() {
896 t.Skip("skipping test in short mode")
897 }
898 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
899 w.Header().Set("Content-Length", "5")
900 w.Header().Set("Content-Type", "text/plain")
901 w.Write([]byte("Hello"))
902 w.(Flusher).Flush()
903 conn, buf, _ := w.(Hijacker).Hijack()
904 buf.Flush()
905 conn.Close()
906 })).ts
907 c := ts.Client()
908
909
910
911
912
913
914
915 const (
916 numClients = 20
917 reqsPerClient = 25
918 )
919 var wg sync.WaitGroup
920 wg.Add(numClients * reqsPerClient)
921 for i := 0; i < numClients; i++ {
922 go func() {
923 for i := 0; i < reqsPerClient; i++ {
924 res, err := c.Get(ts.URL)
925 if err == nil {
926
927
928
929
930
931
932 res.Body.Close()
933 }
934 wg.Done()
935 }
936 }()
937 }
938
939
940 wg.Wait()
941 }
942
943
944
945 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
946 func testTransportHeadResponses(t *testing.T, mode testMode) {
947 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
948 if r.Method != "HEAD" {
949 panic("expected HEAD; got " + r.Method)
950 }
951 w.Header().Set("Content-Length", "123")
952 w.WriteHeader(200)
953 })).ts
954 c := ts.Client()
955
956 for i := 0; i < 2; i++ {
957 res, err := c.Head(ts.URL)
958 if err != nil {
959 t.Errorf("error on loop %d: %v", i, err)
960 continue
961 }
962 if e, g := "123", res.Header.Get("Content-Length"); e != g {
963 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
964 }
965 if e, g := int64(123), res.ContentLength; e != g {
966 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
967 }
968 if all, err := io.ReadAll(res.Body); err != nil {
969 t.Errorf("loop %d: Body ReadAll: %v", i, err)
970 } else if len(all) != 0 {
971 t.Errorf("Bogus body %q", all)
972 }
973 }
974 }
975
976
977
978 func TestTransportHeadChunkedResponse(t *testing.T) {
979 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
980 }
981 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
982 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
983 if r.Method != "HEAD" {
984 panic("expected HEAD; got " + r.Method)
985 }
986 w.Header().Set("Transfer-Encoding", "chunked")
987 w.Header().Set("x-client-ipport", r.RemoteAddr)
988 w.WriteHeader(200)
989 })).ts
990 c := ts.Client()
991
992
993
994 didRead := make(chan bool)
995 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
996 defer SetReadLoopBeforeNextReadHook(nil)
997
998 res1, err := c.Head(ts.URL)
999 <-didRead
1000
1001 if err != nil {
1002 t.Fatalf("request 1 error: %v", err)
1003 }
1004
1005 res2, err := c.Head(ts.URL)
1006 <-didRead
1007
1008 if err != nil {
1009 t.Fatalf("request 2 error: %v", err)
1010 }
1011 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1012 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1013 }
1014 }
1015
// roundTripTests lists how Transport.RoundTrip should handle the request's
// Accept-Encoding header, as exercised by testRoundTripGzip.
var roundTripTests = []struct {
	accept       string // Accept-Encoding set on the request ("" = not set)
	expectAccept string // Accept-Encoding the server should receive
	compressed   bool   // whether the caller gets the body still gzipped
}{
	// No Accept-Encoding: the Transport requests gzip itself and decompresses
	// transparently.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other explicit value is passed through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip is passed through; the caller decompresses the body.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1028
1029
1030 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1031 func testRoundTripGzip(t *testing.T, mode testMode) {
1032 const responseBody = "test response body"
1033 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1034 accept := req.Header.Get("Accept-Encoding")
1035 if expect := req.FormValue("expect_accept"); accept != expect {
1036 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1037 req.FormValue("testnum"), accept, expect)
1038 }
1039 if accept == "gzip" {
1040 rw.Header().Set("Content-Encoding", "gzip")
1041 gz := gzip.NewWriter(rw)
1042 gz.Write([]byte(responseBody))
1043 gz.Close()
1044 } else {
1045 rw.Header().Set("Content-Encoding", accept)
1046 rw.Write([]byte(responseBody))
1047 }
1048 })).ts
1049 tr := ts.Client().Transport.(*Transport)
1050
1051 for i, test := range roundTripTests {
1052
1053 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1054 if test.accept != "" {
1055 req.Header.Set("Accept-Encoding", test.accept)
1056 }
1057 res, err := tr.RoundTrip(req)
1058 if err != nil {
1059 t.Errorf("%d. RoundTrip: %v", i, err)
1060 continue
1061 }
1062 var body []byte
1063 if test.compressed {
1064 var r *gzip.Reader
1065 r, err = gzip.NewReader(res.Body)
1066 if err != nil {
1067 t.Errorf("%d. gzip NewReader: %v", i, err)
1068 continue
1069 }
1070 body, err = io.ReadAll(r)
1071 res.Body.Close()
1072 } else {
1073 body, err = io.ReadAll(res.Body)
1074 }
1075 if err != nil {
1076 t.Errorf("%d. Error: %q", i, err)
1077 continue
1078 }
1079 if g, e := string(body), responseBody; g != e {
1080 t.Errorf("%d. body = %q; want %q", i, g, e)
1081 }
1082 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1083 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1084 }
1085 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1086 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1087 }
1088 }
1089
1090 }
1091
1092 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1093 func testTransportGzip(t *testing.T, mode testMode) {
1094 if mode == http2Mode {
1095 t.Skip("https://go.dev/issue/56020")
1096 }
1097 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1098 const nRandBytes = 1024 * 1024
1099 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1100 if req.Method == "HEAD" {
1101 if g := req.Header.Get("Accept-Encoding"); g != "" {
1102 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1103 }
1104 return
1105 }
1106 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1107 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1108 }
1109 rw.Header().Set("Content-Encoding", "gzip")
1110
1111 var w io.Writer = rw
1112 var buf bytes.Buffer
1113 if req.FormValue("chunked") == "0" {
1114 w = &buf
1115 defer io.Copy(rw, &buf)
1116 defer func() {
1117 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1118 }()
1119 }
1120 gz := gzip.NewWriter(w)
1121 gz.Write([]byte(testString))
1122 if req.FormValue("body") == "large" {
1123 io.CopyN(gz, rand.Reader, nRandBytes)
1124 }
1125 gz.Close()
1126 })).ts
1127 c := ts.Client()
1128
1129 for _, chunked := range []string{"1", "0"} {
1130
1131 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1132 if err != nil {
1133 t.Fatalf("large get: %v", err)
1134 }
1135 buf := make([]byte, len(testString))
1136 n, err := io.ReadFull(res.Body, buf)
1137 if err != nil {
1138 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1139 }
1140 if e, g := testString, string(buf); e != g {
1141 t.Errorf("partial read got %q, expected %q", g, e)
1142 }
1143 res.Body.Close()
1144
1145 n, err = res.Body.Read(buf)
1146 if n != 0 || err == nil {
1147 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1148 }
1149
1150
1151 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1152 if err != nil {
1153 t.Fatal(err)
1154 }
1155 body, err := io.ReadAll(res.Body)
1156 if err != nil {
1157 t.Fatal(err)
1158 }
1159 if g, e := string(body), testString; g != e {
1160 t.Fatalf("body = %q; want %q", g, e)
1161 }
1162 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1163 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1164 }
1165
1166
1167 n, err = res.Body.Read(buf)
1168 if n != 0 || err == nil {
1169 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1170 }
1171 res.Body.Close()
1172 n, err = res.Body.Read(buf)
1173 if n != 0 || err == nil {
1174 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1175 }
1176 }
1177
1178
1179 res, err := c.Head(ts.URL)
1180 if err != nil {
1181 t.Fatalf("Head: %v", err)
1182 }
1183 if res.StatusCode != 200 {
1184 t.Errorf("Head status=%d; want=200", res.StatusCode)
1185 }
1186 }
1187
1188
1189
// transport100ContinueTest is a harness for "Expect: 100-continue" tests.
// A Transport sends a request while the test plays the server over a raw
// connection.
type transport100ContinueTest struct {
	t *testing.T

	// reqdone is closed when the request goroutine finishes; resp and
	// respErr hold the RoundTrip result.
	reqdone chan struct{}
	resp    *Response
	respErr error

	// Server side of the connection, and a buffered reader over it.
	conn   net.Conn
	reader *bufio.Reader
}

// transport100ContinueTestBody is the request body the Transport sends.
const transport100ContinueTestBody = "request body"
1202
1203
1204
1205 func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1206 ln := newLocalListener(t)
1207 defer ln.Close()
1208
1209 test := &transport100ContinueTest{
1210 t: t,
1211 reqdone: make(chan struct{}),
1212 }
1213
1214 tr := &Transport{
1215 ExpectContinueTimeout: timeout,
1216 }
1217 go func() {
1218 defer close(test.reqdone)
1219 body := strings.NewReader(transport100ContinueTestBody)
1220 req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1221 req.Header.Set("Expect", "100-continue")
1222 req.ContentLength = int64(len(transport100ContinueTestBody))
1223 test.resp, test.respErr = tr.RoundTrip(req)
1224 test.resp.Body.Close()
1225 }()
1226
1227 c, err := ln.Accept()
1228 if err != nil {
1229 t.Fatalf("Accept: %v", err)
1230 }
1231 t.Cleanup(func() {
1232 c.Close()
1233 })
1234 br := bufio.NewReader(c)
1235 _, err = ReadRequest(br)
1236 if err != nil {
1237 t.Fatalf("ReadRequest: %v", err)
1238 }
1239 test.conn = c
1240 test.reader = br
1241 t.Cleanup(func() {
1242 <-test.reqdone
1243 tr.CloseIdleConnections()
1244 got, _ := io.ReadAll(test.reader)
1245 if len(got) > 0 {
1246 t.Fatalf("Transport sent unexpected bytes: %q", got)
1247 }
1248 })
1249
1250 return test
1251 }
1252
1253
1254 func (test *transport100ContinueTest) respond(lines ...string) {
1255 for _, line := range lines {
1256 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1257 test.t.Fatalf("Write: %v", err)
1258 }
1259 }
1260 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1261 test.t.Fatalf("Write: %v", err)
1262 }
1263 }
1264
1265
1266 func (test *transport100ContinueTest) wantBodySent() {
1267 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1268 if err != nil {
1269 test.t.Fatalf("unexpected error reading body: %v", err)
1270 }
1271 if got, want := string(got), transport100ContinueTestBody; got != want {
1272 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1273 }
1274 }
1275
1276
1277 func (test *transport100ContinueTest) wantRequestDone(want int) {
1278 <-test.reqdone
1279 if test.respErr != nil {
1280 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1281 }
1282 if got := test.resp.StatusCode; got != want {
1283 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1284 }
1285 }
1286
1287 func TestTransportExpect100ContinueSent(t *testing.T) {
1288 test := newTransport100ContinueTest(t, 1*time.Hour)
1289
1290 test.respond("HTTP/1.1 100 Continue")
1291 test.wantBodySent()
1292 test.respond("HTTP/1.1 200", "Content-Length: 0")
1293 test.wantRequestDone(200)
1294 }
1295
1296 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1297 test := newTransport100ContinueTest(t, 1*time.Hour)
1298
1299 test.respond("HTTP/1.1 200", "Content-Length: 0")
1300 test.wantBodySent()
1301 test.wantRequestDone(200)
1302 }
1303
1304 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1305 test := newTransport100ContinueTest(t, 1*time.Hour)
1306
1307 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1308 test.wantRequestDone(200)
1309 }
1310
1311 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1312 test := newTransport100ContinueTest(t, 1*time.Hour)
1313
1314 test.respond("HTTP/1.1 500", "Content-Length: 0")
1315 test.wantBodySent()
1316 test.wantRequestDone(500)
1317 }
1318
1319 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1320 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1321 test.wantBodySent()
1322 test.respond("HTTP/1.1 200", "Content-Length: 0")
1323 test.wantRequestDone(200)
1324 }
1325
1326 func TestSOCKS5Proxy(t *testing.T) {
1327 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1328 }
1329 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1330 ch := make(chan string, 1)
1331 l := newLocalListener(t)
1332 defer l.Close()
1333 defer close(ch)
1334 proxy := func(t *testing.T) {
1335 s, err := l.Accept()
1336 if err != nil {
1337 t.Errorf("socks5 proxy Accept(): %v", err)
1338 return
1339 }
1340 defer s.Close()
1341 var buf [22]byte
1342 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1343 t.Errorf("socks5 proxy initial read: %v", err)
1344 return
1345 }
1346 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1347 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1348 return
1349 }
1350 if _, err := s.Write([]byte{5, 0}); err != nil {
1351 t.Errorf("socks5 proxy initial write: %v", err)
1352 return
1353 }
1354 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1355 t.Errorf("socks5 proxy second read: %v", err)
1356 return
1357 }
1358 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1359 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1360 return
1361 }
1362 var ipLen int
1363 switch buf[3] {
1364 case 1:
1365 ipLen = net.IPv4len
1366 case 4:
1367 ipLen = net.IPv6len
1368 default:
1369 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1370 return
1371 }
1372 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1373 t.Errorf("socks5 proxy address read: %v", err)
1374 return
1375 }
1376 ip := net.IP(buf[4 : ipLen+4])
1377 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1378 copy(buf[:3], []byte{5, 0, 0})
1379 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1380 t.Errorf("socks5 proxy connect write: %v", err)
1381 return
1382 }
1383 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1384
1385
1386 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1387 targetConn, err := net.Dial("tcp", targetHost)
1388 if err != nil {
1389 t.Errorf("net.Dial failed")
1390 return
1391 }
1392 go io.Copy(targetConn, s)
1393 io.Copy(s, targetConn)
1394 targetConn.Close()
1395 }
1396
1397 pu, err := url.Parse("socks5://" + l.Addr().String())
1398 if err != nil {
1399 t.Fatal(err)
1400 }
1401
1402 sentinelHeader := "X-Sentinel"
1403 sentinelValue := "12345"
1404 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1405 w.Header().Set(sentinelHeader, sentinelValue)
1406 })
1407 for _, useTLS := range []bool{false, true} {
1408 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1409 ts := newClientServerTest(t, mode, h).ts
1410 go proxy(t)
1411 c := ts.Client()
1412 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1413 r, err := c.Head(ts.URL)
1414 if err != nil {
1415 t.Fatal(err)
1416 }
1417 if r.Header.Get(sentinelHeader) != sentinelValue {
1418 t.Errorf("Failed to retrieve sentinel value")
1419 }
1420 got := <-ch
1421 ts.Close()
1422 tsu, err := url.Parse(ts.URL)
1423 if err != nil {
1424 t.Fatal(err)
1425 }
1426 want := "proxy for " + tsu.Host
1427 if got != want {
1428 t.Errorf("got %q, want %q", got, want)
1429 }
1430 })
1431 }
1432 }
1433
1434 func TestTransportProxy(t *testing.T) {
1435 defer afterTest(t)
1436 testCases := []struct{ siteMode, proxyMode testMode }{
1437 {http1Mode, http1Mode},
1438 {http1Mode, https1Mode},
1439 {https1Mode, http1Mode},
1440 {https1Mode, https1Mode},
1441 }
1442 for _, testCase := range testCases {
1443 siteMode := testCase.siteMode
1444 proxyMode := testCase.proxyMode
1445 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1446 siteCh := make(chan *Request, 1)
1447 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1448 siteCh <- r
1449 })
1450 proxyCh := make(chan *Request, 1)
1451 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1452 proxyCh <- r
1453
1454 if r.Method == "CONNECT" {
1455 hijacker, ok := w.(Hijacker)
1456 if !ok {
1457 t.Errorf("hijack not allowed")
1458 return
1459 }
1460 clientConn, _, err := hijacker.Hijack()
1461 if err != nil {
1462 t.Errorf("hijacking failed")
1463 return
1464 }
1465 res := &Response{
1466 StatusCode: StatusOK,
1467 Proto: "HTTP/1.1",
1468 ProtoMajor: 1,
1469 ProtoMinor: 1,
1470 Header: make(Header),
1471 }
1472
1473 targetConn, err := net.Dial("tcp", r.URL.Host)
1474 if err != nil {
1475 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1476 return
1477 }
1478
1479 if err := res.Write(clientConn); err != nil {
1480 t.Errorf("Writing 200 OK failed: %v", err)
1481 return
1482 }
1483
1484 go io.Copy(targetConn, clientConn)
1485 go func() {
1486 io.Copy(clientConn, targetConn)
1487 targetConn.Close()
1488 }()
1489 }
1490 })
1491 ts := newClientServerTest(t, siteMode, h1).ts
1492 proxy := newClientServerTest(t, proxyMode, h2).ts
1493
1494 pu, err := url.Parse(proxy.URL)
1495 if err != nil {
1496 t.Fatal(err)
1497 }
1498
1499
1500
1501
1502 c := proxy.Client()
1503 if siteMode == https1Mode {
1504 c = ts.Client()
1505 }
1506
1507 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1508 if _, err := c.Head(ts.URL); err != nil {
1509 t.Error(err)
1510 }
1511 got := <-proxyCh
1512 c.Transport.(*Transport).CloseIdleConnections()
1513 ts.Close()
1514 proxy.Close()
1515 if siteMode == https1Mode {
1516
1517 if got.Method != "CONNECT" {
1518 t.Errorf("Wrong method for secure proxying: %q", got.Method)
1519 }
1520 gotHost := got.URL.Host
1521 pu, err := url.Parse(ts.URL)
1522 if err != nil {
1523 t.Fatal("Invalid site URL")
1524 }
1525 if wantHost := pu.Host; gotHost != wantHost {
1526 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1527 }
1528
1529
1530 next := <-siteCh
1531 if next.Method != "HEAD" {
1532 t.Errorf("Wrong method at destination: %s", next.Method)
1533 }
1534 if nextURL := next.URL.String(); nextURL != "/" {
1535 t.Errorf("Wrong URL at destination: %s", nextURL)
1536 }
1537 } else {
1538 if got.Method != "HEAD" {
1539 t.Errorf("Wrong method for destination: %q", got.Method)
1540 }
1541 gotURL := got.URL.String()
1542 wantURL := ts.URL + "/"
1543 if gotURL != wantURL {
1544 t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1545 }
1546 }
1547 })
1548 }
1549 }
1550
1551 func TestOnProxyConnectResponse(t *testing.T) {
1552
1553 var tcases = []struct {
1554 proxyStatusCode int
1555 err error
1556 }{
1557 {
1558 StatusOK,
1559 nil,
1560 },
1561 {
1562 StatusForbidden,
1563 errors.New("403"),
1564 },
1565 }
1566 for _, tcase := range tcases {
1567 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1568
1569 })
1570
1571 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1572
1573 if r.Method == "CONNECT" {
1574 if tcase.proxyStatusCode != StatusOK {
1575 w.WriteHeader(tcase.proxyStatusCode)
1576 return
1577 }
1578 hijacker, ok := w.(Hijacker)
1579 if !ok {
1580 t.Errorf("hijack not allowed")
1581 return
1582 }
1583 clientConn, _, err := hijacker.Hijack()
1584 if err != nil {
1585 t.Errorf("hijacking failed")
1586 return
1587 }
1588 res := &Response{
1589 StatusCode: StatusOK,
1590 Proto: "HTTP/1.1",
1591 ProtoMajor: 1,
1592 ProtoMinor: 1,
1593 Header: make(Header),
1594 }
1595
1596 targetConn, err := net.Dial("tcp", r.URL.Host)
1597 if err != nil {
1598 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1599 return
1600 }
1601
1602 if err := res.Write(clientConn); err != nil {
1603 t.Errorf("Writing 200 OK failed: %v", err)
1604 return
1605 }
1606
1607 go io.Copy(targetConn, clientConn)
1608 go func() {
1609 io.Copy(clientConn, targetConn)
1610 targetConn.Close()
1611 }()
1612 }
1613 })
1614 ts := newClientServerTest(t, https1Mode, h1).ts
1615 proxy := newClientServerTest(t, https1Mode, h2).ts
1616
1617 pu, err := url.Parse(proxy.URL)
1618 if err != nil {
1619 t.Fatal(err)
1620 }
1621
1622 c := proxy.Client()
1623
1624 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1625 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1626 if proxyURL.String() != pu.String() {
1627 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1628 }
1629
1630 if "https://"+connectReq.URL.String() != ts.URL {
1631 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1632 }
1633 return tcase.err
1634 }
1635 if _, err := c.Head(ts.URL); err != nil {
1636 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1637 t.Errorf("got %v, want %v", err, tcase.err)
1638 }
1639 }
1640 }
1641 }
1642
1643
1644
// TestTransportProxyHTTPSConnectLeak checks that if the request context is
// canceled while the Transport is waiting for the proxy's reply to a CONNECT,
// the Transport closes its connection to the proxy.
//
// The fake proxy accepts one connection and reads the CONNECT request. It
// cancels the client's context without replying, then expects its next read
// to return io.EOF, which is how the client closing the connection shows up.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ln := newLocalListener(t)
	defer ln.Close()
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Cancel the client's request context without ever answering the
		// CONNECT. The client is then expected to close its side of the
		// connection, which this read observes as io.EOF.
		cancel()
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	// The https target is reached through the proxy via CONNECT; the host
	// name is never resolved because only the proxy is dialed.
	req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to finish its checks, so that its t.Errorf
	// calls happen before this test returns.
	<-listenerDone
}
1708
1709
1710 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1711 defer afterTest(t)
1712
1713 var errDial = errors.New("some dial error")
1714
1715 tr := &Transport{
1716 Proxy: func(*Request) (*url.URL, error) {
1717 return url.Parse("http://proxy.fake.tld/")
1718 },
1719 Dial: func(string, string) (net.Conn, error) {
1720 return nil, errDial
1721 },
1722 }
1723 defer tr.CloseIdleConnections()
1724
1725 c := &Client{Transport: tr}
1726 req, _ := NewRequest("GET", "http://fake.tld", nil)
1727 res, err := c.Do(req)
1728 if err == nil {
1729 res.Body.Close()
1730 t.Fatal("wanted a non-nil error")
1731 }
1732
1733 uerr, ok := err.(*url.Error)
1734 if !ok {
1735 t.Fatalf("got %T, want *url.Error", err)
1736 }
1737 oe, ok := uerr.Err.(*net.OpError)
1738 if !ok {
1739 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1740 }
1741 want := &net.OpError{
1742 Op: "proxyconnect",
1743 Net: "tcp",
1744 Err: errDial,
1745 }
1746 if !reflect.DeepEqual(oe, want) {
1747 t.Errorf("Got error %#v; want %#v", oe, want)
1748 }
1749 }
1750
1751
1752
1753
1754
1755 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1756 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1757 }
1758 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1759 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1760 defer proxy.Close()
1761 c := proxy.Client()
1762
1763 tr := c.Transport.(*Transport)
1764 tr.Proxy = func(*Request) (*url.URL, error) {
1765 u, _ := url.Parse(proxy.URL)
1766 u.User = url.UserPassword("aladdin", "opensesame")
1767 return u, nil
1768 }
1769 h := tr.ProxyConnectHeader
1770 if h == nil {
1771 h = make(Header)
1772 }
1773 tr.ProxyConnectHeader = h.Clone()
1774
1775 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1776 if err != nil {
1777 t.Fatal(err)
1778 }
1779 _, err = c.Do(req)
1780 if err == nil {
1781 t.Errorf("unexpected Get success")
1782 }
1783
1784 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1785 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1786 }
1787 }
1788
1789
1790
1791
1792
1793 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1794 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1795 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1796 w.Header().Set("Content-Encoding", "gzip")
1797 w.Write(rgz)
1798 })).ts
1799
1800 c := ts.Client()
1801 res, err := c.Get(ts.URL)
1802 if err != nil {
1803 t.Fatal(err)
1804 }
1805 body, err := io.ReadAll(res.Body)
1806 if err != nil {
1807 t.Fatal(err)
1808 }
1809 if !bytes.Equal(body, rgz) {
1810 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1811 body, rgz)
1812 }
1813 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1814 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1815 }
1816 }
1817
1818
1819
1820 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1821 func testTransportGzipShort(t *testing.T, mode testMode) {
1822 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1823 w.Header().Set("Content-Encoding", "gzip")
1824 w.Write([]byte{0x1f, 0x8b})
1825 })).ts
1826
1827 c := ts.Client()
1828 res, err := c.Get(ts.URL)
1829 if err != nil {
1830 t.Fatal(err)
1831 }
1832 defer res.Body.Close()
1833 _, err = io.ReadAll(res.Body)
1834 if err == nil {
1835 t.Fatal("Expect an error from reading a body.")
1836 }
1837 if err != io.ErrUnexpectedEOF {
1838 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1839 }
1840 }
1841
1842
// waitNumGoroutine gives exiting goroutines a chance to finish. While the
// goroutine count exceeds nmax it sleeps 50ms and runs a GC, at most 10
// times. It returns the last goroutine count observed.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
1852
1853
// TestTransportPersistConnLeak checks that a burst of concurrent requests does
// not leave extra goroutines running once the requests have finished and the
// Transport's idle connections have been closed.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// The handler signals gotReqCh on entry and then blocks until unblockCh
	// is closed, so all numReq requests can be in flight at the same time.
	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count, taken before any request is issued.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait for each request to either reach the handler or fail.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// ok
		case <-failed:
			// The client reported an error (already logged); count it
			// and keep waiting for the rest.
		}
	}

	// Goroutine count with the requests in flight; only used in the log
	// message below.
	nhigh := runtime.NumGoroutine()

	// Release every blocked handler.
	close(unblockCh)

	// Wait for every client goroutine to get its result.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Tolerate a few extra goroutines (the same slack given to
	// waitNumGoroutine above) rather than requiring an exact return to the
	// baseline.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
1925
1926
1927
// TestTransportPersistConnLeakShortBody checks that requests whose body is
// longer than their declared ContentLength fail without leaving extra
// goroutines behind.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
}
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// The handler is empty; the test only looks at client-side errors and
	// goroutine counts.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count, taken before any request is issued.
	n0 := runtime.NumGoroutine()
	body := []byte("Hello")
	for i := 0; i < 20; i++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		// Declare two bytes fewer than the body actually contains.
		req.ContentLength = int64(len(body) - 2)
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// The counts are logged on every run to help diagnose flakes; a few
	// extra goroutines (the same slack given to waitNumGoroutine) are
	// tolerated.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if int(growth) > 5 {
		t.Error("too many new goroutines")
	}
}
1968
1969
// countedConn wraps a dialed connection so that countingDialer can attach a
// finalizer to the wrapper. All net.Conn methods are promoted unchanged from
// the embedded connection.
type countedConn struct {
	net.Conn // the underlying dialed connection
}
1973
1974
// countingDialer dials with an embedded-by-value net.Dialer and counts the
// connections it hands out. total counts every successful dial. live counts
// those whose countedConn wrapper has not been finalized yet. mu guards both
// counters.
type countingDialer struct {
	dialer net.Dialer
	mu     sync.Mutex
	total  int64
	live   int64
}
1980
1981 func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
1982 conn, err := d.dialer.DialContext(ctx, network, address)
1983 if err != nil {
1984 return nil, err
1985 }
1986
1987 counted := new(countedConn)
1988 counted.Conn = conn
1989
1990 d.mu.Lock()
1991 defer d.mu.Unlock()
1992 d.total++
1993 d.live++
1994
1995 runtime.SetFinalizer(counted, d.decrement)
1996 return counted, nil
1997 }
1998
1999 func (d *countingDialer) decrement(*countedConn) {
2000 d.mu.Lock()
2001 defer d.mu.Unlock()
2002 d.live--
2003 }
2004
2005 func (d *countingDialer) Read() (total, live int64) {
2006 d.mu.Lock()
2007 defer d.mu.Unlock()
2008 return d.total, d.live
2009 }
2010
// TestTransportPersistConnLeakNeverIdle checks that client connections that
// break before ever becoming idle do not stay reachable. Repeated failing
// requests must eventually let at least one dialed connection be garbage
// collected.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
}
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Close the connection without sending any response, so every
		// client request fails.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	// Dial through a countingDialer so the test can tell, via finalizers,
	// when dialed connections have been collected.
	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	body := []byte("Hello")
	// Repeat failing requests, running a GC after each, until some dialed
	// connection has been finalized (live < total). Give up after 1<<12
	// iterations.
	for i := 0; ; i++ {
		total, live := d.Read()
		if live < total {
			break
		}
		if i >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("expected broken connection")
		}

		runtime.GC()
	}
}
2051
// countedContext wraps a context so that contextCounter can attach a
// finalizer to the wrapper. All context.Context methods are promoted
// unchanged from the embedded context.
type countedContext struct {
	context.Context // the wrapped context
}
2055
// contextCounter counts contexts wrapped by Track whose wrappers have not yet
// been finalized. mu guards live.
type contextCounter struct {
	mu   sync.Mutex
	live int64 // tracked wrappers still reachable (not yet finalized)
}
2060
2061 func (cc *contextCounter) Track(ctx context.Context) context.Context {
2062 counted := new(countedContext)
2063 counted.Context = ctx
2064 cc.mu.Lock()
2065 defer cc.mu.Unlock()
2066 cc.live++
2067 runtime.SetFinalizer(counted, cc.decrement)
2068 return counted
2069 }
2070
2071 func (cc *contextCounter) decrement(*countedContext) {
2072 cc.mu.Lock()
2073 defer cc.mu.Unlock()
2074 cc.live--
2075 }
2076
2077 func (cc *contextCounter) Read() (live int64) {
2078 cc.mu.Lock()
2079 defer cc.mu.Unlock()
2080 return cc.live
2081 }
2082
2083 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2084 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2085 }
2086 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2087 if mode == http2Mode {
2088 t.Skip("https://go.dev/issue/56021")
2089 }
2090
2091 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2092 runtime.Gosched()
2093 w.WriteHeader(StatusOK)
2094 })).ts
2095
2096 c := ts.Client()
2097 c.Transport.(*Transport).MaxConnsPerHost = 1
2098
2099 ctx := context.Background()
2100 body := []byte("Hello")
2101 doPosts := func(cc *contextCounter) {
2102 var wg sync.WaitGroup
2103 for n := 64; n > 0; n-- {
2104 wg.Add(1)
2105 go func() {
2106 defer wg.Done()
2107
2108 ctx := cc.Track(ctx)
2109 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2110 if err != nil {
2111 t.Error(err)
2112 }
2113
2114 _, err = c.Do(req.WithContext(ctx))
2115 if err != nil {
2116 t.Errorf("Do failed with error: %v", err)
2117 }
2118 }()
2119 }
2120 wg.Wait()
2121 }
2122
2123 var initialCC contextCounter
2124 doPosts(&initialCC)
2125
2126
2127
2128
2129 var flushCC contextCounter
2130 for i := 0; ; i++ {
2131 live := initialCC.Read()
2132 if live == 0 {
2133 break
2134 }
2135 if i >= 100 {
2136 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2137 }
2138 doPosts(&flushCC)
2139 runtime.GC()
2140 }
2141 }
2142
2143
2144 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2145 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2146 var tr *Transport
2147
2148 unblockCh := make(chan bool, 1)
2149 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2150 <-unblockCh
2151 tr.CloseIdleConnections()
2152 })).ts
2153 c := ts.Client()
2154 tr = c.Transport.(*Transport)
2155
2156 didreq := make(chan bool)
2157 go func() {
2158 res, err := c.Get(ts.URL)
2159 if err != nil {
2160 t.Error(err)
2161 } else {
2162 res.Body.Close()
2163 }
2164 didreq <- true
2165 }()
2166 unblockCh <- true
2167 <-didreq
2168 }
2169
2170
2171
2172
2173
2174 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2175 func testIssue3644(t *testing.T, mode testMode) {
2176 const numFoos = 5000
2177 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2178 w.Header().Set("Connection", "close")
2179 for i := 0; i < numFoos; i++ {
2180 w.Write([]byte("foo "))
2181 }
2182 })).ts
2183 c := ts.Client()
2184 res, err := c.Get(ts.URL)
2185 if err != nil {
2186 t.Fatal(err)
2187 }
2188 defer res.Body.Close()
2189 bs, err := io.ReadAll(res.Body)
2190 if err != nil {
2191 t.Fatal(err)
2192 }
2193 if len(bs) != numFoos*len("foo ") {
2194 t.Errorf("unexpected response length")
2195 }
2196 }
2197
2198
2199
2200 func TestIssue3595(t *testing.T) {
2201
2202 run(t, testIssue3595, testNotParallel)
2203 }
2204 func testIssue3595(t *testing.T, mode testMode) {
2205 runTimeSensitiveTest(t, []time.Duration{
2206 1 * time.Millisecond,
2207 5 * time.Millisecond,
2208 10 * time.Millisecond,
2209 50 * time.Millisecond,
2210 100 * time.Millisecond,
2211 500 * time.Millisecond,
2212 time.Second,
2213 5 * time.Second,
2214 }, func(t *testing.T, timeout time.Duration) error {
2215 SetRSTAvoidanceDelay(t, timeout)
2216 t.Logf("set RST avoidance delay to %v", timeout)
2217
2218 const deniedMsg = "sorry, denied."
2219 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2220 Error(w, deniedMsg, StatusUnauthorized)
2221 }))
2222
2223
2224 defer cst.close()
2225 ts := cst.ts
2226 c := ts.Client()
2227
2228 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2229 if err != nil {
2230 return fmt.Errorf("Post: %v", err)
2231 }
2232 got, err := io.ReadAll(res.Body)
2233 if err != nil {
2234 return fmt.Errorf("Body ReadAll: %v", err)
2235 }
2236 t.Logf("server response:\n%s", got)
2237 if !strings.Contains(string(got), deniedMsg) {
2238
2239
2240 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2241 }
2242 return nil
2243 })
2244 }
2245
2246
2247
2248 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2249 func testChunkedNoContent(t *testing.T, mode testMode) {
2250 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2251 w.WriteHeader(StatusNoContent)
2252 })).ts
2253
2254 c := ts.Client()
2255 for _, closeBody := range []bool{true, false} {
2256 const n = 4
2257 for i := 1; i <= n; i++ {
2258 res, err := c.Get(ts.URL)
2259 if err != nil {
2260 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2261 } else {
2262 if closeBody {
2263 res.Body.Close()
2264 }
2265 }
2266 }
2267 }
2268 }
2269
// TestTransportConcurrency sends numReqs requests through one Client from a
// pool of 2*maxProcs worker goroutines. It checks that each response body
// echoes its own request's "echo" parameter.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Use a smaller configuration in -short mode.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS to maxProcs for this test and restore the previous
	// value when it returns.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// wg counts outstanding requests; workers call wg.Done once per request
	// they handle, whether it succeeded or not.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Also count pending dials in wg so that wg.Wait below does not return
	// while a dial is still in progress.
	// NOTE(review): SetPendingDialHooks is a test hook defined outside this
	// chunk; the "pending dial" meaning is inferred from its name and the
	// Add/Done pairing passed to it.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs when the test returns ends the workers' range loops.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Connection resets on NetBSD are tracked in
						// go.dev/issue/52168; log them instead of failing.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2332
2333 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2334 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2335 mux := NewServeMux()
2336 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2337 io.Copy(w, neverEnding('a'))
2338 })
2339 ts := newClientServerTest(t, mode, mux).ts
2340
2341 connc := make(chan net.Conn, 1)
2342 c := ts.Client()
2343 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2344 conn, err := net.Dial(n, addr)
2345 if err != nil {
2346 return nil, err
2347 }
2348 select {
2349 case connc <- conn:
2350 default:
2351 }
2352 return conn, nil
2353 }
2354
2355 res, err := c.Get(ts.URL + "/get")
2356 if err != nil {
2357 t.Fatalf("Error issuing GET: %v", err)
2358 }
2359 defer res.Body.Close()
2360
2361 conn := <-connc
2362 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2363 _, err = io.Copy(io.Discard, res.Body)
2364 if err == nil {
2365 t.Errorf("Unexpected successful copy")
2366 }
2367 }
2368
2369 func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
2370 run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
2371 }
2372 func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
2373 const debug = false
2374 mux := NewServeMux()
2375 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2376 io.Copy(w, neverEnding('a'))
2377 })
2378 mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
2379 defer r.Body.Close()
2380 io.Copy(io.Discard, r.Body)
2381 })
2382 ts := newClientServerTest(t, mode, mux).ts
2383 timeout := 100 * time.Millisecond
2384
2385 c := ts.Client()
2386 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2387 conn, err := net.Dial(n, addr)
2388 if err != nil {
2389 return nil, err
2390 }
2391 conn.SetDeadline(time.Now().Add(timeout))
2392 if debug {
2393 conn = NewLoggingConn("client", conn)
2394 }
2395 return conn, nil
2396 }
2397
2398 getFailed := false
2399 nRuns := 5
2400 if testing.Short() {
2401 nRuns = 1
2402 }
2403 for i := 0; i < nRuns; i++ {
2404 if debug {
2405 println("run", i+1, "of", nRuns)
2406 }
2407 sres, err := c.Get(ts.URL + "/get")
2408 if err != nil {
2409 if !getFailed {
2410
2411 getFailed = true
2412 t.Logf("increasing timeout")
2413 i--
2414 timeout *= 10
2415 continue
2416 }
2417 t.Errorf("Error issuing GET: %v", err)
2418 break
2419 }
2420 req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
2421 _, err = c.Do(req)
2422 if err == nil {
2423 sres.Body.Close()
2424 t.Errorf("Unexpected successful PUT")
2425 break
2426 }
2427 sres.Body.Close()
2428 }
2429 if debug {
2430 println("tests complete; waiting for handlers to finish")
2431 }
2432 ts.Close()
2433 }
2434
2435 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2436 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2437 if testing.Short() {
2438 t.Skip("skipping timeout test in -short mode")
2439 }
2440
2441 timeout := 2 * time.Millisecond
2442 retry := true
2443 for retry && !t.Failed() {
2444 var srvWG sync.WaitGroup
2445 inHandler := make(chan bool, 1)
2446 mux := NewServeMux()
2447 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2448 inHandler <- true
2449 srvWG.Done()
2450 })
2451 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2452 inHandler <- true
2453 <-r.Context().Done()
2454 srvWG.Done()
2455 })
2456 ts := newClientServerTest(t, mode, mux).ts
2457
2458 c := ts.Client()
2459 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2460
2461 retry = false
2462 srvWG.Add(3)
2463 tests := []struct {
2464 path string
2465 wantTimeout bool
2466 }{
2467 {path: "/fast"},
2468 {path: "/slow", wantTimeout: true},
2469 {path: "/fast"},
2470 }
2471 for i, tt := range tests {
2472 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2473 req = req.WithT(t)
2474 res, err := c.Do(req)
2475 <-inHandler
2476 if err != nil {
2477 uerr, ok := err.(*url.Error)
2478 if !ok {
2479 t.Errorf("error is not a url.Error; got: %#v", err)
2480 continue
2481 }
2482 nerr, ok := uerr.Err.(net.Error)
2483 if !ok {
2484 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2485 continue
2486 }
2487 if !nerr.Timeout() {
2488 t.Errorf("want timeout error; got: %q", nerr)
2489 continue
2490 }
2491 if !tt.wantTimeout {
2492 if !retry {
2493
2494 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2495 timeout *= 2
2496 retry = true
2497 }
2498 }
2499 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2500 t.Errorf("%d. unexpected error: %v", i, err)
2501 }
2502 continue
2503 }
2504 if tt.wantTimeout {
2505 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2506 continue
2507 }
2508 if res.StatusCode != 200 {
2509 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2510 }
2511 }
2512
2513 srvWG.Wait()
2514 ts.Close()
2515 }
2516 }
2517
func TestTransportCancelRequest(t *testing.T) {
	run(t, testTransportCancelRequest, []testMode{http1Mode})
}

// testTransportCancelRequest calls Transport.CancelRequest while a response
// body is still being streamed. It checks three things:
//   - the next Body.Read fails with ExportErrRequestCanceled;
//   - no further bytes are delivered;
//   - the Transport's pending-request count drains to zero.
func testTransportCancelRequest(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// unblockc keeps the handler (and therefore the response body) open
	// after msg has been flushed. It is closed when the test returns.
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // deliver msg to the client before blocking
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Read exactly the flushed prefix; the rest of the body is still pending.
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	tr.CancelRequest(req)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// The pending-request count is not required to be zero immediately
	// after cancellation, so poll until it drains.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2571
// testTransportCancelRequestInDo runs c.Do in a goroutine and then calls
// Transport.CancelRequest repeatedly until that Do returns. If body is
// non-nil, it is sent as the request body.
func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// The handler blocks on unblockc. Closing it at return releases any
	// handler that is still waiting.
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// donec is closed once Do has returned, whatever its result.
	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	go func() {
		defer close(donec)
		c.Do(req)
	}()

	// This unbuffered send completes only once the handler is running, so
	// the request is known to be in flight. It also releases that handler.
	unblockc <- true
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		tr.CancelRequest(req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}

// TestTransportCancelRequestInDo cancels an in-flight request that has no body.
func TestTransportCancelRequestInDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportCancelRequestInDo(t, mode, nil)
	}, []testMode{http1Mode})
}

// TestTransportCancelRequestWithBodyInDo cancels an in-flight request that
// carries a one-byte body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
	}, []testMode{http1Mode})
}
2618
// TestTransportCancelRequestInDial cancels a request while its Dial is
// blocked. An ordered event log checks that Do returns the "request canceled
// while waiting for connection" error after the cancel, without waiting for
// the Dial to finish.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records events from several goroutines. The final text is
	// compared against an exact expected transcript.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is only closed when the test returns. Do must therefore
	// return while the Dial below is still blocked.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// The test sends true on inDial once Dial has started. This guarantees
	// the cancel happens while Dial is in progress.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	inDial <- true

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	// Cancel a second time; a repeated call must be harmless.
	tr.CancelRequest(req)

	if d, ok := t.Deadline(); ok {
		// If Do never returns, panic shortly before the test deadline so the
		// failure report includes the events logged so far.
		// NOTE(review): logbuf is read here without the logger's lock; this
		// only matters on the hang path.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2676
func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }

// testCancelRequestWithChannel is the Request.Cancel-channel counterpart of
// testTransportCancelRequest. After part of the body has been read, closing
// req.Cancel must make the next Body.Read fail with ExportErrRequestCanceled,
// and the Transport's pending-request count must drain to zero.
func testCancelRequestWithChannel(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// unblockc keeps the handler (and the response body) open after msg has
	// been flushed. It is closed when the test returns.
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // deliver msg to the client before blocking
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Read exactly the flushed prefix; the rest of the body is still pending.
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	close(cancel)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// The pending-request count is not required to be zero immediately
	// after cancellation, so poll until it drains.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2731
2732
// TestCancelRequestWithBodyWithChannel is like TestCancelRequestWithChannel,
// but the request is a POST that carries a body.
func TestCancelRequestWithBodyWithChannel(t *testing.T) {
	run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
}

// testCancelRequestWithBodyWithChannel sends a POST with a body and reads
// part of the response. Closing req.Cancel must then make the next Body.Read
// fail with ExportErrRequestCanceled, and the Transport's pending-request
// count must drain to zero.
func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// unblockc keeps the handler (and the response body) open after msg has
	// been flushed. It is closed when the test returns.
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // deliver msg to the client before blocking
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Read exactly the flushed prefix; the rest of the body is still pending.
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	close(cancel)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// The pending-request count is not required to be zero immediately
	// after cancellation, so poll until it drains.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2789
2790 func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
2791 run(t, func(t *testing.T, mode testMode) {
2792 testCancelRequestWithChannelBeforeDo(t, mode, false)
2793 })
2794 }
2795 func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
2796 run(t, func(t *testing.T, mode testMode) {
2797 testCancelRequestWithChannelBeforeDo(t, mode, true)
2798 })
2799 }
2800 func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
2801 unblockc := make(chan bool)
2802 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2803 <-unblockc
2804 })).ts
2805 defer close(unblockc)
2806
2807 c := ts.Client()
2808
2809 req, _ := NewRequest("GET", ts.URL, nil)
2810 if withCtx {
2811 ctx, cancel := context.WithCancel(context.Background())
2812 cancel()
2813 req = req.WithContext(ctx)
2814 } else {
2815 ch := make(chan struct{})
2816 req.Cancel = ch
2817 close(ch)
2818 }
2819
2820 _, err := c.Do(req)
2821 if ue, ok := err.(*url.Error); ok {
2822 err = ue.Err
2823 }
2824 if withCtx {
2825 if err != context.Canceled {
2826 t.Errorf("Do error = %v; want %v", err, context.Canceled)
2827 }
2828 } else {
2829 if err == nil || !strings.Contains(err.Error(), "canceled") {
2830 t.Errorf("Do error = %v; want cancellation", err)
2831 }
2832 }
2833 }
2834
2835
// TestTransportCancelBeforeResponseHeaders cancels a request after it has been
// written to the connection but before any response bytes arrive. RoundTrip
// must then fail with ExportErrRequestCanceled.
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
	defer afterTest(t)

	// Each Dial returns one end of an in-memory net.Pipe and hands the other
	// end to the test, which plays a server that never responds.
	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	sc := <-serverConnCh
	// Read the start of the request line to confirm the request is being
	// written before canceling.
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	tr.CancelRequest(req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
	}
}
2875
2876
2877
2878
func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }

// testTransportCloseResponseBody reads part of an endless response and then
// closes the body early. Closing must succeed, and it must stop the server:
// a subsequent handler Write fails, which is reported on writeErr.
func testTransportCloseResponseBody(t *testing.T, mode testMode) {
	writeErr := make(chan error, 1)
	msg := []byte("young\n")
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Stream msg forever, flushing each copy, until a write fails.
		for {
			_, err := w.Write(msg)
			if err != nil {
				writeErr <- err
				return
			}
			w.(Flusher).Flush()
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	defer tr.CancelRequest(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	// Consume a few copies to confirm the stream is flowing.
	const repeats = 3
	buf := make([]byte, len(msg)*repeats)
	want := bytes.Repeat(msg, repeats)

	_, err = io.ReadFull(res.Body, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, want) {
		t.Fatalf("read %q; want %q", buf, want)
	}

	// Close before EOF; this must not report an error.
	if err := res.Body.Close(); err != nil {
		t.Errorf("Close = %v", err)
	}

	if err := <-writeErr; err == nil {
		t.Errorf("expected non-nil write error")
	}
}
2925
// fooProto is a RoundTripper that never touches the network. It answers
// every request with a 200 OK whose body echoes the requested URL.
// TestTransportAltProto registers it for the "foo" scheme.
type fooProto struct{}

// RoundTrip returns a canned 200 response naming req.URL and a nil error.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	bodyText := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(bodyText)),
	}, nil
}
2937
2938 func TestTransportAltProto(t *testing.T) {
2939 defer afterTest(t)
2940 tr := &Transport{}
2941 c := &Client{Transport: tr}
2942 tr.RegisterProtocol("foo", fooProto{})
2943 res, err := c.Get("foo://bar.com/path")
2944 if err != nil {
2945 t.Fatal(err)
2946 }
2947 bodyb, err := io.ReadAll(res.Body)
2948 if err != nil {
2949 t.Fatal(err)
2950 }
2951 body := string(bodyb)
2952 if e := "You wanted foo://bar.com/path"; body != e {
2953 t.Errorf("got response %q, want %q", body, e)
2954 }
2955 }
2956
2957 func TestTransportNoHost(t *testing.T) {
2958 defer afterTest(t)
2959 tr := &Transport{}
2960 _, err := tr.RoundTrip(&Request{
2961 Header: make(Header),
2962 URL: &url.URL{
2963 Scheme: "http",
2964 },
2965 })
2966 want := "http: no Host in request URL"
2967 if got := fmt.Sprint(err); got != want {
2968 t.Errorf("error = %v; want %q", err, want)
2969 }
2970 }
2971
2972
// TestTransportEmptyMethod checks that an outgoing request whose Method is the
// empty string is written on the wire as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = "" // the zero value is expected to be sent as GET
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if text := string(dump); !strings.Contains(text, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
2984
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }

// testTransportSocketLateBinding checks connection late binding. While a
// second request ("/bar") is blocked dialing a new connection, the first
// request's connection ("/foo") becomes free. "/bar" must then be served on
// that existing connection: its RemoteAddr must equal /foo's.
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate holds the /foo handler open after its headers are sent.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// Each real dial consumes one token from dialGate. A closed dialGate
	// makes pending dials fail. While a dial waits for a token, it announces
	// itself on dialing.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one real dial, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Reading /foo's headers proves its dial completed and its handler
		// is running. Wait for the /bar request to block in Dial, then let
		// /foo finish so its connection can be reused for /bar.
		if mode == http2Mode {
			// HTTP/2 multiplexes streams over one connection, so no second
			// Dial is expected. Wait briefly for one anyway; the test should
			// pass regardless of how far /bar has progressed by then.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3064
3065
// TestTransportReading100Continue sends numReqs POSTs over an in-memory
// connection. A fake server answers each one with a "100 Continue"
// informational response immediately followed by the final 200. Each request
// must yield the 200, with the Echo-Request-Id header matching its Request-Id.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response is the server side of one connection. It reads
	// requests from r and writes the raw responses to w. The status of the
	// first response defaults to "100 Continue" but can be overridden with
	// X-Want-Response-Code. It returns after answering the numReqs-th request
	// or on EOF.
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				// This check assumes all requests arrive in order on this
				// one connection, so n matches the client's request index.
				// NOTE(review): err is always nil at this point.
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// Write the informational response and the final 200 back to
			// back, converting the template's "\n" line endings to "\r\n".
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	tr := &Transport{
		// Dial builds an in-memory connection from two pipes. The client
		// writes to sw, which the server reads as sr. The server writes to
		// cw, which the client reads as cr. Closing the conn closes both
		// writers. rwTestConn is defined elsewhere in this package.
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe() // client -> server
			cr, cw := io.Pipe() // server -> client
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse issues req and checks the final status code and the echoed
	// request ID. It also drains the body so the connection can be reused.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Requests are numbered from 1 to match the server's per-connection counter.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3165
3166
3167
3168 func TestTransportIgnore1xxResponses(t *testing.T) {
3169 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3170 }
3171 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3172 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3173 conn, buf, _ := w.(Hijacker).Hijack()
3174 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3175 buf.Flush()
3176 conn.Close()
3177 }))
3178 cst.tr.DisableKeepAlives = true
3179
3180 var got strings.Builder
3181
3182 req, _ := NewRequest("GET", cst.ts.URL, nil)
3183 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3184 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3185 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3186 return nil
3187 },
3188 }))
3189 res, err := cst.c.Do(req)
3190 if err != nil {
3191 t.Fatal(err)
3192 }
3193 defer res.Body.Close()
3194
3195 res.Write(&got)
3196 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3197 if got.String() != want {
3198 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3199 }
3200 }
3201
3202 func TestTransportLimits1xxResponses(t *testing.T) {
3203 run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
3204 }
3205 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3206 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3207 conn, buf, _ := w.(Hijacker).Hijack()
3208 for i := 0; i < 10; i++ {
3209 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
3210 }
3211 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3212 buf.Flush()
3213 conn.Close()
3214 }))
3215 cst.tr.DisableKeepAlives = true
3216
3217 res, err := cst.c.Get(cst.ts.URL)
3218 if res != nil {
3219 defer res.Body.Close()
3220 }
3221 got := fmt.Sprint(err)
3222 wantSub := "too many 1xx informational responses"
3223 if !strings.Contains(got, wantSub) {
3224 t.Errorf("Get error = %v; want substring %q", err, wantSub)
3225 }
3226 }
3227
3228
3229
3230 func TestTransportTreat101Terminal(t *testing.T) {
3231 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3232 }
3233 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3234 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3235 conn, buf, _ := w.(Hijacker).Hijack()
3236 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3237 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3238 buf.Flush()
3239 conn.Close()
3240 }))
3241 res, err := cst.c.Get(cst.ts.URL)
3242 if err != nil {
3243 t.Fatal(err)
3244 }
3245 defer res.Body.Close()
3246 if res.StatusCode != StatusSwitchingProtocols {
3247 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3248 }
3249 }
3250
// proxyFromEnvTest is one entry in the ProxyFromEnvironment test table. It
// holds the proxy environment to install, the request URL to resolve, and the
// expected outcome.
type proxyFromEnvTest struct {
	req string // request URL; "" means "http://example.com"

	env      string // value for HTTP_PROXY / http_proxy
	httpsenv string // value for HTTPS_PROXY / https_proxy
	noenv    string // value for NO_PROXY / no_proxy
	reqmeth  string // value for REQUEST_METHOD

	want    string // expected proxy URL formatted with %s ("<nil>" for none)
	wanterr error  // expected error, compared by its %v text
}

// String describes the case for failure messages. It lists each non-empty
// environment setting, then the effective request URL, separated by spaces.
func (t proxyFromEnvTest) String() string {
	var parts []string
	if t.env != "" {
		parts = append(parts, fmt.Sprintf("http_proxy=%q", t.env))
	}
	if t.httpsenv != "" {
		parts = append(parts, fmt.Sprintf("https_proxy=%q", t.httpsenv))
	}
	if t.noenv != "" {
		parts = append(parts, fmt.Sprintf("no_proxy=%q", t.noenv))
	}
	if t.reqmeth != "" {
		parts = append(parts, fmt.Sprintf("request_method=%q", t.reqmeth))
	}
	req := t.req
	if req == "" {
		req = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", req))
	return strings.TrimSpace(strings.Join(parts, " "))
}
3293
// proxyFromEnvTests is the table shared by the ProxyFromEnvironment tests.
// Each environment variable is set to the corresponding field, so an empty
// field means the variable is set to "".
var proxyFromEnvTests = []proxyFromEnvTest{
	// A proxy value without a scheme is treated as http://. Explicit schemes
	// (http, https, socks5) are kept as given.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},

	// An http request uses the HTTP proxy setting, even when an HTTPS proxy is set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https request uses the HTTPS proxy setting. A scheme-less value
	// still defaults to http://.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// When REQUEST_METHOD is set, as in a CGI environment, the HTTP proxy
	// setting is refused with an error instead of being used.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// With no proxy configured, no proxy is used.
	{want: "<nil>"},

	// NO_PROXY matching:
	//   - "example.com" excludes the domain itself and its subdomains.
	//   - ".example.com" does not exclude the bare domain.
	//   - Matching is not a plain string suffix: "ample.com" does not match
	//     example.com.
	//   - Unrelated entries do not match.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3323
3324 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3325 t.Helper()
3326 reqURL := tt.req
3327 if reqURL == "" {
3328 reqURL = "http://example.com"
3329 }
3330 req, _ := NewRequest("GET", reqURL, nil)
3331 url, err := proxyForRequest(req)
3332 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3333 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3334 return
3335 }
3336 if got := fmt.Sprintf("%s", url); got != tt.want {
3337 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3338 }
3339 }
3340
3341 func TestProxyFromEnvironment(t *testing.T) {
3342 ResetProxyEnv()
3343 defer ResetProxyEnv()
3344 for _, tt := range proxyFromEnvTests {
3345 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3346 os.Setenv("HTTP_PROXY", tt.env)
3347 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3348 os.Setenv("NO_PROXY", tt.noenv)
3349 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3350 ResetCachedEnvironment()
3351 return ProxyFromEnvironment(req)
3352 })
3353 }
3354 }
3355
3356 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3357 ResetProxyEnv()
3358 defer ResetProxyEnv()
3359 for _, tt := range proxyFromEnvTests {
3360 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3361 os.Setenv("http_proxy", tt.env)
3362 os.Setenv("https_proxy", tt.httpsenv)
3363 os.Setenv("no_proxy", tt.noenv)
3364 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3365 ResetCachedEnvironment()
3366 return ProxyFromEnvironment(req)
3367 })
3368 }
3369 }
3370
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}

// testIdleConnChannelLeak sends requests to several distinct hosts, with
// keep-alives both disabled and enabled. Afterwards the Transport's
// idle-connection wait map (IdleConnWaitMapSizeForTesting) must be empty,
// meaning no per-host wait entries were leaked.
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// n counts handler invocations.
	// NOTE(review): n is never checked afterwards.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	// didRead receives one value per firing of the read-loop hook. The hook
	// is presumably called by each connection's read loop after it finishes
	// a response; the test waits for nReqs firings before checking the map.
	const nReqs = 5
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Send every host to the test server. Each distinct fake host name
	// still counts as a separate destination for the Transport.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// First, without keep-alives.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response body is not closed. The handler writes nothing,
			// so the body is expected to be empty.
		}

		// The read loops finish their bookkeeping asynchronously after
		// each response. Wait for all of them before inspecting the map,
		// or the check would race with that cleanup.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3426
3427
3428
3429
3430 func TestTransportClosesRequestBody(t *testing.T) {
3431 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3432 }
3433 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3434 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3435 io.Copy(io.Discard, r.Body)
3436 })).ts
3437
3438 c := ts.Client()
3439
3440 closes := 0
3441
3442 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3443 if err != nil {
3444 t.Fatal(err)
3445 }
3446 res.Body.Close()
3447 if closes != 1 {
3448 t.Errorf("closes = %d; want 1", closes)
3449 }
3450 }
3451
// TestTransportTLSHandshakeTimeout dials a listener that accepts the TCP
// connection but never speaks TLS. The Transport's TLSHandshakeTimeout must
// surface as a *url.Error wrapping a net.Error whose Timeout() is true and
// whose text mentions "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()
	// testdonec is closed when the test returns. That releases the accepted
	// connection held by the goroutine below.
	testdonec := make(chan struct{})
	defer close(testdonec)

	// Accept one connection and hold it open without writing anything, so
	// the client's TLS handshake cannot complete.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testdonec
		c.Close()
	}()

	tr := &Transport{
		// Send every dial to the silent listener, whatever the URL's host.
		Dial: func(_, _ string) (net.Conn, error) {
			return net.Dial("tcp", ln.Addr().String())
		},
		TLSHandshakeTimeout: 250 * time.Millisecond,
	}
	cl := &Client{Transport: tr}
	_, err := cl.Get("https://dummy.tld/")
	if err == nil {
		t.Error("expected error")
		return
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Errorf("expected url.Error; got %#v", err)
		return
	}
	ne, ok := ue.Err.(net.Error)
	if !ok {
		t.Errorf("expected net.Error; got %#v", err)
		return
	}
	if !ne.Timeout() {
		t.Errorf("expected timeout error; got %v", err)
	}
	if !strings.Contains(err.Error(), "handshake timeout") {
		t.Errorf("expected 'handshake timeout' in error; got %v", err)
	}
}
3501
3502
3503 func TestTLSServerClosesConnection(t *testing.T) {
3504 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3505 }
3506 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3507 closedc := make(chan bool, 1)
3508 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3509 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3510 conn, _, _ := w.(Hijacker).Hijack()
3511 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3512 conn.Close()
3513 closedc <- true
3514 return
3515 }
3516 fmt.Fprintf(w, "hello")
3517 })).ts
3518
3519 c := ts.Client()
3520 tr := c.Transport.(*Transport)
3521
3522 var nSuccess = 0
3523 var errs []error
3524 const trials = 20
3525 for i := 0; i < trials; i++ {
3526 tr.CloseIdleConnections()
3527 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3528 if err != nil {
3529 t.Fatal(err)
3530 }
3531 <-closedc
3532 slurp, err := io.ReadAll(res.Body)
3533 if err != nil {
3534 t.Fatal(err)
3535 }
3536 if string(slurp) != "foo" {
3537 t.Errorf("Got %q, want foo", slurp)
3538 }
3539
3540
3541
3542 res, err = c.Get(ts.URL + "/")
3543 if err != nil {
3544 errs = append(errs, err)
3545 continue
3546 }
3547 slurp, err = io.ReadAll(res.Body)
3548 if err != nil {
3549 errs = append(errs, err)
3550 continue
3551 }
3552 nSuccess++
3553 }
3554 if nSuccess > 0 {
3555 t.Logf("successes = %d of %d", nSuccess, trials)
3556 } else {
3557 t.Errorf("All runs failed:")
3558 }
3559 for _, err := range errs {
3560 t.Logf(" err: %v", err)
3561 }
3562 }
3563
3564
3565
3566
// byteFromChanReader is an io.Reader that returns one byte per Read, taken
// from the channel. Read blocks until a byte is available and returns io.EOF
// once the channel is closed and drained.
type byteFromChanReader chan byte

func (ch byteFromChanReader) Read(p []byte) (int, error) {
	// A zero-length read must not consume a byte from the channel.
	if len(p) == 0 {
		return 0, nil
	}
	if b, ok := <-ch; ok {
		p[0] = b
		return 1, nil
	}
	return 0, io.EOF
}
3580
3581
3582
3583
3584
3585
3586
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportNoReuseAfterEarlyResponse sends a large POST whose final body
// byte is withheld. The server answers ("foo") before the body has been fully
// written. A subsequent GET must not be sent on that half-written connection:
// it must be answered by the normal handler ("bar"). Only then is the final
// body byte released.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten how long the Transport waits for an unfinished request write
	// before deciding about connection reuse; restored on return.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn is the hijacked server side of the POST's connection.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	var getOkay bool
	// copying tracks the goroutine that drains the POST body on the server.
	var copying sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: reply immediately on the raw connection, then keep reading
		// (and discarding) the rest of the request body in the background.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))

		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The body is bodySize-1 bytes of 'x' followed by one byte that only
	// arrives when the test sends it on finalBit, so the write stays
	// incomplete until then.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	finalBit <- 'x'
	close(finalBit)
}
3654
3655
3656
3657 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3658 func testTransportIssue10457(t *testing.T, mode testMode) {
3659 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3660
3661
3662
3663
3664
3665 conn, _, _ := w.(Hijacker).Hijack()
3666 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3667 conn.Close()
3668 })).ts
3669 c := ts.Client()
3670
3671 res, err := c.Get(ts.URL)
3672 if err != nil {
3673 t.Fatalf("Get: %v", err)
3674 }
3675 defer res.Body.Close()
3676
3677
3678
3679
3680 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3681 t.Errorf("Foo header = %q; want %q", got, want)
3682 }
3683 }
3684
3685 type closerFunc func() error
3686
3687 func (f closerFunc) Close() error { return f() }
3688
3689 type writerFuncConn struct {
3690 net.Conn
3691 write func(p []byte) (n int, err error)
3692 }
3693
3694 func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708 func TestRetryRequestsOnError(t *testing.T) {
3709 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3710 }
3711 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3712 newRequest := func(method, urlStr string, body io.Reader) *Request {
3713 req, err := NewRequest(method, urlStr, body)
3714 if err != nil {
3715 t.Fatal(err)
3716 }
3717 return req
3718 }
3719
3720 testCases := []struct {
3721 name string
3722 failureN int
3723 failureErr error
3724
3725
3726
3727 req func() *Request
3728 reqString string
3729 }{
3730 {
3731 name: "IdempotentNoBodySomeWritten",
3732
3733
3734 failureN: 1,
3735
3736 failureErr: ExportErrServerClosedIdle,
3737 req: func() *Request {
3738 return newRequest("GET", "http://fake.golang", nil)
3739 },
3740 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3741 },
3742 {
3743 name: "IdempotentGetBodySomeWritten",
3744
3745
3746 failureN: 1,
3747
3748 failureErr: ExportErrServerClosedIdle,
3749 req: func() *Request {
3750 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3751 },
3752 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3753 },
3754 {
3755 name: "NothingWrittenNoBody",
3756
3757
3758 failureN: 0,
3759 failureErr: errors.New("second write fails"),
3760 req: func() *Request {
3761 return newRequest("DELETE", "http://fake.golang", nil)
3762 },
3763 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3764 },
3765 {
3766 name: "NothingWrittenGetBody",
3767
3768
3769 failureN: 0,
3770 failureErr: errors.New("second write fails"),
3771
3772
3773 req: func() *Request {
3774 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3775 },
3776 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3777 },
3778 }
3779
3780 for _, tc := range testCases {
3781 t.Run(tc.name, func(t *testing.T) {
3782 var (
3783 mu sync.Mutex
3784 logbuf strings.Builder
3785 )
3786 logf := func(format string, args ...any) {
3787 mu.Lock()
3788 defer mu.Unlock()
3789 fmt.Fprintf(&logbuf, format, args...)
3790 logbuf.WriteByte('\n')
3791 }
3792
3793 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3794 logf("Handler")
3795 w.Header().Set("X-Status", "ok")
3796 })).ts
3797
3798 var writeNumAtomic int32
3799 c := ts.Client()
3800 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3801 logf("Dial")
3802 c, err := net.Dial(network, ts.Listener.Addr().String())
3803 if err != nil {
3804 logf("Dial error: %v", err)
3805 return nil, err
3806 }
3807 return &writerFuncConn{
3808 Conn: c,
3809 write: func(p []byte) (n int, err error) {
3810 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3811 logf("intentional write failure")
3812 return tc.failureN, tc.failureErr
3813 }
3814 logf("Write(%q)", p)
3815 return c.Write(p)
3816 },
3817 }, nil
3818 }
3819
3820 SetRoundTripRetried(func() {
3821 logf("Retried.")
3822 })
3823 defer SetRoundTripRetried(nil)
3824
3825 for i := 0; i < 3; i++ {
3826 t0 := time.Now()
3827 req := tc.req()
3828 res, err := c.Do(req)
3829 if err != nil {
3830 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3831 mu.Lock()
3832 got := logbuf.String()
3833 mu.Unlock()
3834 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3835 }
3836 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
3837 }
3838 res.Body.Close()
3839 if res.Request != req {
3840 t.Errorf("Response.Request != original request; want identical Request")
3841 }
3842 }
3843
3844 mu.Lock()
3845 got := logbuf.String()
3846 mu.Unlock()
3847 want := fmt.Sprintf(`Dial
3848 Write("%s")
3849 Handler
3850 intentional write failure
3851 Retried.
3852 Dial
3853 Write("%s")
3854 Handler
3855 Write("%s")
3856 Handler
3857 `, tc.reqString, tc.reqString, tc.reqString)
3858 if got != want {
3859 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3860 }
3861 })
3862 }
3863 }
3864
3865
3866 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
3867 func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
3868 readBody := make(chan error, 1)
3869 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3870 _, err := io.ReadAll(r.Body)
3871 readBody <- err
3872 })).ts
3873 c := ts.Client()
3874 fakeErr := errors.New("fake error")
3875 didClose := make(chan bool, 1)
3876 req, _ := NewRequest("POST", ts.URL, struct {
3877 io.Reader
3878 io.Closer
3879 }{
3880 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3881 closerFunc(func() error {
3882 select {
3883 case didClose <- true:
3884 default:
3885 }
3886 return nil
3887 }),
3888 })
3889 res, err := c.Do(req)
3890 if res != nil {
3891 defer res.Body.Close()
3892 }
3893 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3894 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3895 }
3896 if err := <-readBody; err == nil {
3897 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3898 }
3899 select {
3900 case <-didClose:
3901 default:
3902 t.Errorf("didn't see Body.Close")
3903 }
3904 }
3905
3906 func TestTransportDialTLS(t *testing.T) {
3907 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
3908 }
3909 func testTransportDialTLS(t *testing.T, mode testMode) {
3910 var mu sync.Mutex
3911 var gotReq, didDial bool
3912
3913 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3914 mu.Lock()
3915 gotReq = true
3916 mu.Unlock()
3917 })).ts
3918 c := ts.Client()
3919 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
3920 mu.Lock()
3921 didDial = true
3922 mu.Unlock()
3923 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3924 if err != nil {
3925 return nil, err
3926 }
3927 return c, c.Handshake()
3928 }
3929
3930 res, err := c.Get(ts.URL)
3931 if err != nil {
3932 t.Fatal(err)
3933 }
3934 res.Body.Close()
3935 mu.Lock()
3936 if !gotReq {
3937 t.Error("didn't get request")
3938 }
3939 if !didDial {
3940 t.Error("didn't use dial hook")
3941 }
3942 }
3943
3944 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
3945 func testTransportDialContext(t *testing.T, mode testMode) {
3946 var mu sync.Mutex
3947 var gotReq bool
3948 var receivedContext context.Context
3949
3950 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3951 mu.Lock()
3952 gotReq = true
3953 mu.Unlock()
3954 })).ts
3955 c := ts.Client()
3956 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3957 mu.Lock()
3958 receivedContext = ctx
3959 mu.Unlock()
3960 return net.Dial(netw, addr)
3961 }
3962
3963 req, err := NewRequest("GET", ts.URL, nil)
3964 if err != nil {
3965 t.Fatal(err)
3966 }
3967 ctx := context.WithValue(context.Background(), "some-key", "some-value")
3968 res, err := c.Do(req.WithContext(ctx))
3969 if err != nil {
3970 t.Fatal(err)
3971 }
3972 res.Body.Close()
3973 mu.Lock()
3974 if !gotReq {
3975 t.Error("didn't get request")
3976 }
3977 if receivedContext != ctx {
3978 t.Error("didn't receive correct context")
3979 }
3980 }
3981
3982 func TestTransportDialTLSContext(t *testing.T) {
3983 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
3984 }
3985 func testTransportDialTLSContext(t *testing.T, mode testMode) {
3986 var mu sync.Mutex
3987 var gotReq bool
3988 var receivedContext context.Context
3989
3990 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3991 mu.Lock()
3992 gotReq = true
3993 mu.Unlock()
3994 })).ts
3995 c := ts.Client()
3996 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3997 mu.Lock()
3998 receivedContext = ctx
3999 mu.Unlock()
4000 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4001 if err != nil {
4002 return nil, err
4003 }
4004 return c, c.HandshakeContext(ctx)
4005 }
4006
4007 req, err := NewRequest("GET", ts.URL, nil)
4008 if err != nil {
4009 t.Fatal(err)
4010 }
4011 ctx := context.WithValue(context.Background(), "some-key", "some-value")
4012 res, err := c.Do(req.WithContext(ctx))
4013 if err != nil {
4014 t.Fatal(err)
4015 }
4016 res.Body.Close()
4017 mu.Lock()
4018 if !gotReq {
4019 t.Error("didn't get request")
4020 }
4021 if receivedContext != ctx {
4022 t.Error("didn't receive correct context")
4023 }
4024 }
4025
4026
4027
4028 func TestRoundTripReturnsProxyError(t *testing.T) {
4029 badProxy := func(*Request) (*url.URL, error) {
4030 return nil, errors.New("errorMessage")
4031 }
4032
4033 tr := &Transport{Proxy: badProxy}
4034
4035 req, _ := NewRequest("GET", "http://example.com", nil)
4036
4037 _, err := tr.RoundTrip(req)
4038
4039 if err == nil {
4040 t.Error("Expected proxy error to be returned by RoundTrip")
4041 }
4042 }
4043
4044
4045 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4046 tr := &Transport{}
4047 wantIdle := func(when string, n int) bool {
4048 got := tr.IdleConnCountForTesting("http", "example.com")
4049 if got == n {
4050 return true
4051 }
4052 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4053 return false
4054 }
4055 wantIdle("start", 0)
4056 if !tr.PutIdleTestConn("http", "example.com") {
4057 t.Fatal("put failed")
4058 }
4059 if !tr.PutIdleTestConn("http", "example.com") {
4060 t.Fatal("second put failed")
4061 }
4062 wantIdle("after put", 2)
4063 tr.CloseIdleConnections()
4064 if !tr.IsIdleForTesting() {
4065 t.Error("should be idle after CloseIdleConnections")
4066 }
4067 wantIdle("after close idle", 0)
4068 if tr.PutIdleTestConn("http", "example.com") {
4069 t.Fatal("put didn't fail")
4070 }
4071 wantIdle("after second put", 0)
4072
4073 tr.QueueForIdleConnForTesting()
4074 if tr.IsIdleForTesting() {
4075 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4076 }
4077 if !tr.PutIdleTestConn("http", "example.com") {
4078 t.Fatal("after re-activation")
4079 }
4080 wantIdle("after final put", 1)
4081 }
4082
4083
4084
4085 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4086 tr := &Transport{}
4087 wantIdle := func(when string, n int) bool {
4088 got := tr.IdleConnCountForTesting("https", "example.com:443")
4089 if got == n {
4090 return true
4091 }
4092 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4093 return false
4094 }
4095 wantIdle("start", 0)
4096 alt := funcRoundTripper(func() {})
4097 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4098 t.Fatal("put failed")
4099 }
4100 wantIdle("after put", 1)
4101 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4102 GotConn: func(httptrace.GotConnInfo) {
4103
4104 t.Error("GotConn called")
4105 },
4106 })
4107 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4108 _, err := tr.RoundTrip(req)
4109 if err != errFakeRoundTrip {
4110 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4111 }
4112 wantIdle("after round trip", 1)
4113 }
4114
4115 func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
4116 run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
4117 }
4118 func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
4119 if testing.Short() {
4120 t.Skip("skipping in short mode")
4121 }
4122
4123 timeout := 1 * time.Millisecond
4124 retry := true
4125 for retry {
4126 trFunc := func(tr *Transport) {
4127 tr.MaxConnsPerHost = 1
4128 tr.MaxIdleConnsPerHost = 1
4129 tr.IdleConnTimeout = timeout
4130 }
4131 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
4132
4133 retry = false
4134 tooShort := func(err error) bool {
4135 if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
4136 return false
4137 }
4138 if !retry {
4139 t.Helper()
4140 t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
4141 timeout *= 2
4142 retry = true
4143 cst.close()
4144 }
4145 return true
4146 }
4147
4148 if _, err := cst.c.Get(cst.ts.URL); err != nil {
4149 if tooShort(err) {
4150 continue
4151 }
4152 t.Fatalf("got error: %s", err)
4153 }
4154
4155 time.Sleep(10 * timeout)
4156 if _, err := cst.c.Get(cst.ts.URL); err != nil {
4157 if tooShort(err) {
4158 continue
4159 }
4160 t.Fatalf("got error: %s", err)
4161 }
4162 }
4163 }
4164
4165
4166
4167
4168
4169 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4170 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4171 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4172 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4173 t.Error("Transport advertised gzip support in the Accept header")
4174 }
4175 if r.Header.Get("Range") == "" {
4176 t.Error("no Range in request")
4177 }
4178 })).ts
4179 c := ts.Client()
4180
4181 req, _ := NewRequest("GET", ts.URL, nil)
4182 req.Header.Set("Range", "bytes=7-11")
4183 res, err := c.Do(req)
4184 if err != nil {
4185 t.Fatal(err)
4186 }
4187 res.Body.Close()
4188 }
4189
4190
4191 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4192 func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4193 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4194
4195 var b [1024]byte
4196 w.Write(b[:])
4197 })).ts
4198 tr := ts.Client().Transport.(*Transport)
4199
4200 req, err := NewRequest("GET", ts.URL, nil)
4201 if err != nil {
4202 t.Fatal(err)
4203 }
4204 res, err := tr.RoundTrip(req)
4205 if err != nil {
4206 t.Fatal(err)
4207 }
4208
4209
4210
4211 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4212 t.Fatal(err)
4213 }
4214
4215 req2, err := NewRequest("GET", ts.URL, nil)
4216 if err != nil {
4217 t.Fatal(err)
4218 }
4219 tr.CancelRequest(req)
4220 res, err = tr.RoundTrip(req2)
4221 if err != nil {
4222 t.Fatal(err)
4223 }
4224 res.Body.Close()
4225 }
4226
4227
4228 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4229 run(t, testTransportContentEncodingCaseInsensitive)
4230 }
4231 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4232 for _, ce := range []string{"gzip", "GZIP"} {
4233 ce := ce
4234 t.Run(ce, func(t *testing.T) {
4235 const encodedString = "Hello Gopher"
4236 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4237 w.Header().Set("Content-Encoding", ce)
4238 gz := gzip.NewWriter(w)
4239 gz.Write([]byte(encodedString))
4240 gz.Close()
4241 })).ts
4242
4243 res, err := ts.Client().Get(ts.URL)
4244 if err != nil {
4245 t.Fatal(err)
4246 }
4247
4248 body, err := io.ReadAll(res.Body)
4249 res.Body.Close()
4250 if err != nil {
4251 t.Fatal(err)
4252 }
4253
4254 if string(body) != encodedString {
4255 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4256 }
4257 })
4258 }
4259 }
4260
4261 func TestTransportDialCancelRace(t *testing.T) {
4262 run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
4263 }
4264 func testTransportDialCancelRace(t *testing.T, mode testMode) {
4265 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
4266 tr := ts.Client().Transport.(*Transport)
4267
4268 req, err := NewRequest("GET", ts.URL, nil)
4269 if err != nil {
4270 t.Fatal(err)
4271 }
4272 SetEnterRoundTripHook(func() {
4273 tr.CancelRequest(req)
4274 })
4275 defer SetEnterRoundTripHook(nil)
4276 res, err := tr.RoundTrip(req)
4277 if err != ExportErrRequestCanceled {
4278 t.Errorf("expected canceled request error; got %v", err)
4279 if err == nil {
4280 res.Body.Close()
4281 }
4282 }
4283 }
4284
4285
4286 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4287 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4288 }
4289 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4290 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4291 func(tr *Transport) {
4292 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4293
4294 return &funcConn{
4295 read: func([]byte) (int, error) {
4296 return 0, errors.New("error")
4297 },
4298 write: func([]byte) (int, error) {
4299 return 0, errors.New("error")
4300 },
4301 }, nil
4302 }
4303 },
4304 ).ts
4305
4306
4307
4308
4309
4310 SetEnterRoundTripHook(func() {
4311 time.Sleep(1 * time.Millisecond)
4312 })
4313 defer SetEnterRoundTripHook(nil)
4314 var closes int
4315 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4316 if err == nil {
4317 t.Fatalf("expected request to fail, but it did not")
4318 }
4319 if closes != 1 {
4320 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4321 }
4322 }
4323
4324
4325
4326
4327 type logWritesConn struct {
4328 net.Conn
4329
4330 w io.Writer
4331
4332 rch <-chan io.Reader
4333 r io.Reader
4334
4335 mu sync.Mutex
4336 writes []string
4337 }
4338
4339 func (c *logWritesConn) Write(p []byte) (n int, err error) {
4340 c.mu.Lock()
4341 defer c.mu.Unlock()
4342 c.writes = append(c.writes, string(p))
4343 return c.w.Write(p)
4344 }
4345
4346 func (c *logWritesConn) Read(p []byte) (n int, err error) {
4347 if c.r == nil {
4348 c.r = <-c.rch
4349 }
4350 return c.r.Read(p)
4351 }
4352
4353 func (c *logWritesConn) Close() error { return nil }
4354
4355
4356 func TestTransportFlushesBodyChunks(t *testing.T) {
4357 defer afterTest(t)
4358 resBody := make(chan io.Reader, 1)
4359 connr, connw := io.Pipe()
4360 lw := &logWritesConn{
4361 rch: resBody,
4362 w: connw,
4363 }
4364 tr := &Transport{
4365 Dial: func(network, addr string) (net.Conn, error) {
4366 return lw, nil
4367 },
4368 }
4369 bodyr, bodyw := io.Pipe()
4370 go func() {
4371 defer bodyw.Close()
4372 for i := 0; i < 3; i++ {
4373 fmt.Fprintf(bodyw, "num%d\n", i)
4374 }
4375 }()
4376 resc := make(chan *Response)
4377 go func() {
4378 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4379 req.Header.Set("User-Agent", "x")
4380 res, err := tr.RoundTrip(req)
4381 if err != nil {
4382 t.Errorf("RoundTrip: %v", err)
4383 close(resc)
4384 return
4385 }
4386 resc <- res
4387
4388 }()
4389
4390 req, err := ReadRequest(bufio.NewReader(connr))
4391 if err != nil {
4392 t.Fatal(err)
4393 }
4394 io.Copy(io.Discard, req.Body)
4395
4396
4397 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4398 res, ok := <-resc
4399 if !ok {
4400 return
4401 }
4402 defer res.Body.Close()
4403
4404 want := []string{
4405 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4406 "5\r\nnum0\n\r\n",
4407 "5\r\nnum1\n\r\n",
4408 "5\r\nnum2\n\r\n",
4409 "0\r\n\r\n",
4410 }
4411 if !reflect.DeepEqual(lw.writes, want) {
4412 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4413 }
4414 }
4415
4416
4417 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4418 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4419 gotReq := make(chan struct{})
4420 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4421 close(gotReq)
4422 }))
4423
4424 pr, pw := io.Pipe()
4425 req, err := NewRequest("POST", cst.ts.URL, pr)
4426 if err != nil {
4427 t.Fatal(err)
4428 }
4429 gotRes := make(chan struct{})
4430 go func() {
4431 defer close(gotRes)
4432 res, err := cst.tr.RoundTrip(req)
4433 if err != nil {
4434 t.Error(err)
4435 return
4436 }
4437 res.Body.Close()
4438 }()
4439
4440 <-gotReq
4441 pw.Close()
4442 <-gotRes
4443 }
4444
4445 type wgReadCloser struct {
4446 io.Reader
4447 wg *sync.WaitGroup
4448 closed bool
4449 }
4450
4451 func (c *wgReadCloser) Close() error {
4452 if c.closed {
4453 return net.ErrClosed
4454 }
4455 c.closed = true
4456 c.wg.Done()
4457 return nil
4458 }
4459
4460
4461 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4462
4463 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4464 }
4465 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4466 if testing.Short() {
4467 t.Skip("skipping in short mode")
4468 }
4469
4470 runTimeSensitiveTest(t, []time.Duration{
4471 1 * time.Millisecond,
4472 5 * time.Millisecond,
4473 10 * time.Millisecond,
4474 50 * time.Millisecond,
4475 100 * time.Millisecond,
4476 500 * time.Millisecond,
4477 time.Second,
4478 5 * time.Second,
4479 }, func(t *testing.T, timeout time.Duration) error {
4480 SetRSTAvoidanceDelay(t, timeout)
4481 t.Logf("set RST avoidance delay to %v", timeout)
4482
4483 const contentLengthLimit = 1024 * 1024
4484 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4485 if r.ContentLength >= contentLengthLimit {
4486 w.WriteHeader(StatusBadRequest)
4487 r.Body.Close()
4488 return
4489 }
4490 w.WriteHeader(StatusOK)
4491 }))
4492
4493
4494 defer cst.close()
4495 ts := cst.ts
4496 c := ts.Client()
4497
4498 count := 100
4499
4500 bigBody := strings.Repeat("a", contentLengthLimit*2)
4501 var wg sync.WaitGroup
4502 defer wg.Wait()
4503 getBody := func() (io.ReadCloser, error) {
4504 wg.Add(1)
4505 body := &wgReadCloser{
4506 Reader: strings.NewReader(bigBody),
4507 wg: &wg,
4508 }
4509 return body, nil
4510 }
4511
4512 for i := 0; i < count; i++ {
4513 reqBody, _ := getBody()
4514 req, err := NewRequest("PUT", ts.URL, reqBody)
4515 if err != nil {
4516 reqBody.Close()
4517 t.Fatal(err)
4518 }
4519 req.ContentLength = int64(len(bigBody))
4520 req.GetBody = getBody
4521
4522 resp, err := c.Do(req)
4523 if err != nil {
4524 return fmt.Errorf("Do %d: %v", i, err)
4525 } else {
4526 resp.Body.Close()
4527 if resp.StatusCode != 400 {
4528 t.Errorf("Expected status code 400, got %v", resp.Status)
4529 }
4530 }
4531 }
4532 return nil
4533 })
4534 }
4535
4536 func TestTransportAutomaticHTTP2(t *testing.T) {
4537 testTransportAutoHTTP(t, &Transport{}, true)
4538 }
4539
4540 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4541 testTransportAutoHTTP(t, &Transport{
4542 ForceAttemptHTTP2: true,
4543 TLSClientConfig: new(tls.Config),
4544 }, true)
4545 }
4546
4547
4548 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4549 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4550 }
4551
4552 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4553 testTransportAutoHTTP(t, &Transport{
4554 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4555 }, false)
4556 }
4557
4558 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4559 testTransportAutoHTTP(t, &Transport{
4560 TLSClientConfig: new(tls.Config),
4561 }, false)
4562 }
4563
4564 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4565 testTransportAutoHTTP(t, &Transport{
4566 ExpectContinueTimeout: 1 * time.Second,
4567 }, true)
4568 }
4569
4570 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4571 var d net.Dialer
4572 testTransportAutoHTTP(t, &Transport{
4573 Dial: d.Dial,
4574 }, false)
4575 }
4576
4577 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4578 var d net.Dialer
4579 testTransportAutoHTTP(t, &Transport{
4580 DialContext: d.DialContext,
4581 }, false)
4582 }
4583
4584 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4585 testTransportAutoHTTP(t, &Transport{
4586 DialTLS: func(network, addr string) (net.Conn, error) {
4587 panic("unused")
4588 },
4589 }, false)
4590 }
4591
4592 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4593 CondSkipHTTP2(t)
4594 _, err := tr.RoundTrip(new(Request))
4595 if err == nil {
4596 t.Error("expected error from RoundTrip")
4597 }
4598 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4599 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4600 }
4601 }
4602
4603
4604
4605
4606
4607
4608
4609
4610 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4611 run(t, testTransportReuseConnEmptyResponseBody)
4612 }
4613 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4614 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4615 w.Header().Set("X-Addr", r.RemoteAddr)
4616
4617 }))
4618 n := 100
4619 if testing.Short() {
4620 n = 10
4621 }
4622 var firstAddr string
4623 for i := 0; i < n; i++ {
4624 res, err := cst.c.Get(cst.ts.URL)
4625 if err != nil {
4626 log.Fatal(err)
4627 }
4628 addr := res.Header.Get("X-Addr")
4629 if i == 0 {
4630 firstAddr = addr
4631 } else if addr != firstAddr {
4632 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4633 }
4634 res.Body.Close()
4635 }
4636 }
4637
4638
4639 func TestNoCrashReturningTransportAltConn(t *testing.T) {
4640 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
4641 if err != nil {
4642 t.Fatal(err)
4643 }
4644 ln := newLocalListener(t)
4645 defer ln.Close()
4646
4647 var wg sync.WaitGroup
4648 SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
4649 defer SetPendingDialHooks(nil, nil)
4650
4651 testDone := make(chan struct{})
4652 defer close(testDone)
4653 go func() {
4654 tln := tls.NewListener(ln, &tls.Config{
4655 NextProtos: []string{"foo"},
4656 Certificates: []tls.Certificate{cert},
4657 })
4658 sc, err := tln.Accept()
4659 if err != nil {
4660 t.Error(err)
4661 return
4662 }
4663 if err := sc.(*tls.Conn).Handshake(); err != nil {
4664 t.Error(err)
4665 return
4666 }
4667 <-testDone
4668 sc.Close()
4669 }()
4670
4671 addr := ln.Addr().String()
4672
4673 req, _ := NewRequest("GET", "https://fake.tld/", nil)
4674 cancel := make(chan struct{})
4675 req.Cancel = cancel
4676
4677 doReturned := make(chan bool, 1)
4678 madeRoundTripper := make(chan bool, 1)
4679
4680 tr := &Transport{
4681 DisableKeepAlives: true,
4682 TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
4683 "foo": func(authority string, c *tls.Conn) RoundTripper {
4684 madeRoundTripper <- true
4685 return funcRoundTripper(func() {
4686 t.Error("foo RoundTripper should not be called")
4687 })
4688 },
4689 },
4690 Dial: func(_, _ string) (net.Conn, error) {
4691 panic("shouldn't be called")
4692 },
4693 DialTLS: func(_, _ string) (net.Conn, error) {
4694 tc, err := tls.Dial("tcp", addr, &tls.Config{
4695 InsecureSkipVerify: true,
4696 NextProtos: []string{"foo"},
4697 })
4698 if err != nil {
4699 return nil, err
4700 }
4701 if err := tc.Handshake(); err != nil {
4702 return nil, err
4703 }
4704 close(cancel)
4705 <-doReturned
4706 return tc, nil
4707 },
4708 }
4709 c := &Client{Transport: tr}
4710
4711 _, err = c.Do(req)
4712 if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
4713 t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
4714 }
4715
4716 doReturned <- true
4717 <-madeRoundTripper
4718 wg.Wait()
4719 }
4720
4721 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4722 run(t, func(t *testing.T, mode testMode) {
4723 testTransportReuseConnection_Gzip(t, mode, true)
4724 })
4725 }
4726
4727 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4728 run(t, func(t *testing.T, mode testMode) {
4729 testTransportReuseConnection_Gzip(t, mode, false)
4730 })
4731 }
4732
4733
4734 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
4735 addr := make(chan string, 2)
4736 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4737 addr <- r.RemoteAddr
4738 w.Header().Set("Content-Encoding", "gzip")
4739 if chunked {
4740 w.(Flusher).Flush()
4741 }
4742 w.Write(rgz)
4743 })).ts
4744 c := ts.Client()
4745
4746 trace := &httptrace.ClientTrace{
4747 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
4748 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
4749 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
4750 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
4751 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
4752 }
4753 ctx := httptrace.WithClientTrace(context.Background(), trace)
4754
4755 for i := 0; i < 2; i++ {
4756 req, _ := NewRequest("GET", ts.URL, nil)
4757 req = req.WithContext(ctx)
4758 res, err := c.Do(req)
4759 if err != nil {
4760 t.Fatal(err)
4761 }
4762 buf := make([]byte, len(rgz))
4763 if n, err := io.ReadFull(res.Body, buf); err != nil {
4764 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
4765 }
4766
4767
4768
4769 }
4770 a1, a2 := <-addr, <-addr
4771 if a1 != a2 {
4772 t.Fatalf("didn't reuse connection")
4773 }
4774 }
4775
4776 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4777 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4778 if mode == http2Mode {
4779 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4780 }
4781 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4782 if r.URL.Path == "/long" {
4783 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4784 }
4785 })).ts
4786 c := ts.Client()
4787 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4788
4789 if res, err := c.Get(ts.URL); err != nil {
4790 t.Fatal(err)
4791 } else {
4792 res.Body.Close()
4793 }
4794
4795 res, err := c.Get(ts.URL + "/long")
4796 if err == nil {
4797 defer res.Body.Close()
4798 var n int64
4799 for k, vv := range res.Header {
4800 for _, v := range vv {
4801 n += int64(len(k)) + int64(len(v))
4802 }
4803 }
4804 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4805 }
4806 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4807 t.Errorf("got error: %v; want %q", err, want)
4808 }
4809 }
4810
4811 func TestTransportEventTrace(t *testing.T) {
4812 run(t, func(t *testing.T, mode testMode) {
4813 testTransportEventTrace(t, mode, false)
4814 }, testNotParallel)
4815 }
4816
4817
4818 func TestTransportEventTrace_NoHooks(t *testing.T) {
4819 run(t, func(t *testing.T, mode testMode) {
4820 testTransportEventTrace(t, mode, true)
4821 }, testNotParallel)
4822 }
4823
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the fake
// host "dns-is-faked.golang". That host is resolved to the test server's IP
// via nettrace.LookupIPAltResolverKey. The test then checks the
// httptrace.ClientTrace events logged for the POST. It finally sends a GET
// with the same trace and checks that GetConn is logged a second time.
// When noHooks is true, the trace is emptied before use and only the POST's
// response is checked.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// The WroteRequest hook sends on this channel. When hooks are enabled the
	// POST handler waits for it, so WroteRequest is logged before the
	// response is written.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET only needs an empty response.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one formatted line to buf. mu guards buf because trace
	// hooks may be invoked on goroutines other than the test's
	// (presumably the Transport's connection goroutines).
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Resolve the fake host name to the test server's IP instead of doing a
	// real DNS lookup; a lookup for any other host is a test failure.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook; the request must still succeed with an empty trace.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// No hooks ran, so there is no event log to inspect; the successful
		// round trip above is the whole check in this mode.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce requires sub to appear exactly once in the log; wantOnceOrMore
	// requires at least one occurrence.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// Host and Content-Length are not set in req.Header, but they are
		// still reported to WroteHeaderField when written.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// Name resolution goes through the fake resolver, so no UDP (DNS)
	// connection should ever be reported.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Send a second request with the same trace. It must log another
	// GetConn, for two in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
5005
5006 func TestTransportEventTraceTLSVerify(t *testing.T) {
5007 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5008 }
5009 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5010 var mu sync.Mutex
5011 var buf strings.Builder
5012 logf := func(format string, args ...any) {
5013 mu.Lock()
5014 defer mu.Unlock()
5015 fmt.Fprintf(&buf, format, args...)
5016 buf.WriteByte('\n')
5017 }
5018
5019 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5020 t.Error("Unexpected request")
5021 }), func(ts *httptest.Server) {
5022 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5023 logf("%s", p)
5024 return len(p), nil
5025 }), "", 0)
5026 }).ts
5027
5028 certpool := x509.NewCertPool()
5029 certpool.AddCert(ts.Certificate())
5030
5031 c := &Client{Transport: &Transport{
5032 TLSClientConfig: &tls.Config{
5033 ServerName: "dns-is-faked.golang",
5034 RootCAs: certpool,
5035 },
5036 }}
5037
5038 trace := &httptrace.ClientTrace{
5039 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5040 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5041 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5042 },
5043 }
5044
5045 req, _ := NewRequest("GET", ts.URL, nil)
5046 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5047 _, err := c.Do(req)
5048 if err == nil {
5049 t.Error("Expected request to fail TLS verification")
5050 }
5051
5052 mu.Lock()
5053 got := buf.String()
5054 mu.Unlock()
5055
5056 wantOnce := func(sub string) {
5057 if strings.Count(got, sub) != 1 {
5058 t.Errorf("expected substring %q exactly once in output.", sub)
5059 }
5060 }
5061
5062 wantOnce("TLSHandshakeStart")
5063 wantOnce("TLSHandshakeDone")
5064 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5065
5066 if t.Failed() {
5067 t.Errorf("Output:\n%s", got)
5068 }
5069 }
5070
// isDNSHijacked records, once per test binary, whether the system resolver
// returned addresses for a name that should not resolve.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)

// skipIfDNSHijacked skips t when the system resolver returns addresses for
// "dns-should-not-resolve.golang". Such a resolver (presumably one that
// rewrites nonexistent-domain answers) would break tests that rely on the
// name failing to resolve. The lookup runs at most once per process.
func skipIfDNSHijacked(t *testing.T) {
	isDNSHijackedOnce.Do(func() {
		// A lookup error is the expected outcome and just leaves addrs empty.
		if addrs, _ := net.LookupHost("dns-should-not-resolve.golang"); len(addrs) > 0 {
			isDNSHijacked = true
		}
	})
	if !isDNSHijacked {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
5088
5089 func TestTransportEventTraceRealDNS(t *testing.T) {
5090 skipIfDNSHijacked(t)
5091 defer afterTest(t)
5092 tr := &Transport{}
5093 defer tr.CloseIdleConnections()
5094 c := &Client{Transport: tr}
5095
5096 var mu sync.Mutex
5097 var buf strings.Builder
5098 logf := func(format string, args ...any) {
5099 mu.Lock()
5100 defer mu.Unlock()
5101 fmt.Fprintf(&buf, format, args...)
5102 buf.WriteByte('\n')
5103 }
5104
5105 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5106 trace := &httptrace.ClientTrace{
5107 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5108 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5109 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5110 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5111 }
5112 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5113
5114 resp, err := c.Do(req)
5115 if err == nil {
5116 resp.Body.Close()
5117 t.Fatal("expected error during DNS lookup")
5118 }
5119
5120 mu.Lock()
5121 got := buf.String()
5122 mu.Unlock()
5123
5124 wantSub := func(sub string) {
5125 if !strings.Contains(got, sub) {
5126 t.Errorf("expected substring %q in output.", sub)
5127 }
5128 }
5129 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5130 wantSub("DNSDone: {Addrs:[] Err:")
5131 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5132 t.Errorf("should not see Connect events")
5133 }
5134 if t.Failed() {
5135 t.Errorf("Output:\n%s", got)
5136 }
5137 }
5138
5139
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digit characters is rejected with a *url.Error that describes the bad
// port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	// errors.As, rather than a direct type assertion, keeps this check valid
	// even if the *url.Error is wrapped by another error.
	var ue *url.Error
	if !errors.As(err, &ue) {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	got := ue.Err.Error()
	want := `invalid port ":123foo" after host`
	if got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5156
5157
5158
5159 func TestTLSHandshakeTrace(t *testing.T) {
5160 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5161 }
5162 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5163 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5164
5165 var mu sync.Mutex
5166 var start, done bool
5167 trace := &httptrace.ClientTrace{
5168 TLSHandshakeStart: func() {
5169 mu.Lock()
5170 defer mu.Unlock()
5171 start = true
5172 },
5173 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5174 mu.Lock()
5175 defer mu.Unlock()
5176 done = true
5177 if err != nil {
5178 t.Fatal("Expected error to be nil but was:", err)
5179 }
5180 },
5181 }
5182
5183 c := ts.Client()
5184 req, err := NewRequest("GET", ts.URL, nil)
5185 if err != nil {
5186 t.Fatal("Unable to construct test request:", err)
5187 }
5188 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5189
5190 r, err := c.Do(req)
5191 if err != nil {
5192 t.Fatal("Unexpected error making request:", err)
5193 }
5194 r.Body.Close()
5195 mu.Lock()
5196 defer mu.Unlock()
5197 if !start {
5198 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5199 }
5200 if !done {
5201 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5202 }
5203 }
5204
5205 func TestTransportMaxIdleConns(t *testing.T) {
5206 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5207 }
5208 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5209 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5210
5211 })).ts
5212 c := ts.Client()
5213 tr := c.Transport.(*Transport)
5214 tr.MaxIdleConns = 4
5215
5216 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5217 if err != nil {
5218 t.Fatal(err)
5219 }
5220 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5221 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5222 })
5223
5224 hitHost := func(n int) {
5225 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5226 req = req.WithContext(ctx)
5227 res, err := c.Do(req)
5228 if err != nil {
5229 t.Fatal(err)
5230 }
5231 res.Body.Close()
5232 }
5233 for i := 0; i < 4; i++ {
5234 hitHost(i)
5235 }
5236 want := []string{
5237 "|http|host-0.dns-is-faked.golang:" + port,
5238 "|http|host-1.dns-is-faked.golang:" + port,
5239 "|http|host-2.dns-is-faked.golang:" + port,
5240 "|http|host-3.dns-is-faked.golang:" + port,
5241 }
5242 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5243 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5244 }
5245
5246
5247 hitHost(4)
5248 want = []string{
5249 "|http|host-1.dns-is-faked.golang:" + port,
5250 "|http|host-2.dns-is-faked.golang:" + port,
5251 "|http|host-3.dns-is-faked.golang:" + port,
5252 "|http|host-4.dns-is-faked.golang:" + port,
5253 }
5254 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5255 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5256 }
5257 }
5258
5259 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5260 func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5261 if testing.Short() {
5262 t.Skip("skipping in short mode")
5263 }
5264
5265 timeout := 1 * time.Millisecond
5266 timeoutLoop:
5267 for {
5268 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5269
5270 }))
5271 tr := cst.tr
5272 tr.IdleConnTimeout = timeout
5273 defer tr.CloseIdleConnections()
5274 c := &Client{Transport: tr}
5275
5276 idleConns := func() []string {
5277 if mode == http2Mode {
5278 return tr.IdleConnStrsForTesting_h2()
5279 } else {
5280 return tr.IdleConnStrsForTesting()
5281 }
5282 }
5283
5284 var conn string
5285 doReq := func(n int) (timeoutOk bool) {
5286 req, _ := NewRequest("GET", cst.ts.URL, nil)
5287 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5288 PutIdleConn: func(err error) {
5289 if err != nil {
5290 t.Errorf("failed to keep idle conn: %v", err)
5291 }
5292 },
5293 }))
5294 res, err := c.Do(req)
5295 if err != nil {
5296 if strings.Contains(err.Error(), "use of closed network connection") {
5297 t.Logf("req %v: connection closed prematurely", n)
5298 return false
5299 }
5300 }
5301 res.Body.Close()
5302 conns := idleConns()
5303 if len(conns) != 1 {
5304 if len(conns) == 0 {
5305 t.Logf("req %v: no idle conns", n)
5306 return false
5307 }
5308 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5309 }
5310 if conn == "" {
5311 conn = conns[0]
5312 }
5313 if conn != conns[0] {
5314 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5315 return false
5316 }
5317 return true
5318 }
5319 for i := 0; i < 3; i++ {
5320 if !doReq(i) {
5321 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5322 timeout *= 2
5323 cst.close()
5324 continue timeoutLoop
5325 }
5326 time.Sleep(timeout / 2)
5327 }
5328
5329 waitCondition(t, timeout/2, func(d time.Duration) bool {
5330 if got := idleConns(); len(got) != 0 {
5331 if d >= timeout*3/2 {
5332 t.Logf("after %v, idle conns = %q", d, got)
5333 }
5334 return false
5335 }
5336 return true
5337 })
5338 break
5339 }
5340 }
5341
5342
5343
5344
5345
5346
5347
5348
5349
5350
5351
5352
// TestIdleConnH2Crash exercises the following sequence:
//   - an HTTP/2 connection finishes dialing only after the request that
//     triggered the dial has been canceled and has returned its error;
//   - IdleConnTimeout then elapses on that connection.
//
// Judging by the name, it guards against a crash in that situation. Apart
// from the dial checks and the expected request failure, it makes no
// assertions: the test passes if the process survives the final sleep.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Empty handler; the request is canceled while its conn is dialing.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signaled after Do has returned its error. testDone is
	// closed when the test returns, so a DialTLS still blocked below can exit.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request while its connection is still being dialed...
		cancel()

		// ...and hand the conn to the Transport only after Do has returned
		// the cancellation error (or the test is finishing). Presumably this
		// leaves the Transport holding an h2 conn that no request is
		// waiting for.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Sleep 10x IdleConnTimeout so the idle timeout has time to fire on the
	// unused connection.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5403
// funcConn is a net.Conn whose Read and Write delegate to the supplied
// functions and whose Close does nothing. Any other net.Conn method is
// promoted from the embedded Conn, which may be nil.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to c.read.
func (c funcConn) Read(p []byte) (n int, err error) {
	n, err = c.read(p)
	return n, err
}

// Write forwards to c.write.
func (c funcConn) Write(p []byte) (n int, err error) {
	n, err = c.write(p)
	return n, err
}

// Close reports success without closing anything.
func (funcConn) Close() error { return nil }
5413
5414
5415
5416 func TestTransportReturnsPeekError(t *testing.T) {
5417 errValue := errors.New("specific error value")
5418
5419 wrote := make(chan struct{})
5420 var wroteOnce sync.Once
5421
5422 tr := &Transport{
5423 Dial: func(network, addr string) (net.Conn, error) {
5424 c := funcConn{
5425 read: func([]byte) (int, error) {
5426 <-wrote
5427 return 0, errValue
5428 },
5429 write: func(p []byte) (int, error) {
5430 wroteOnce.Do(func() { close(wrote) })
5431 return len(p), nil
5432 },
5433 }
5434 return c, nil
5435 },
5436 }
5437 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5438 if err != errValue {
5439 t.Errorf("error = %#v; want %v", err, errValue)
5440 }
5441 }
5442
5443
// TestTransportIDNA checks that a request to an internationalized (Unicode)
// host name uses the name's punycode form in all of these places:
//   - the Host header;
//   - the GetConn and DNSStart trace events;
//   - the DNS lookup;
//   - in HTTP/2 mode, the TLS ServerName.
func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
func testTransportIDNA(t *testing.T, mode testMode) {
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is filled in by SplitHostPort below (the := there reuses this
	// variable) before any request reaches the handler.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		// Marks responses that really came from this handler.
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Resolve the punycode name to the test server's IP; a lookup of any
	// other name (e.g. the raw Unicode form) is a test failure.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5512
5513
5514 func TestTransportProxyConnectHeader(t *testing.T) {
5515 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5516 }
5517 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5518 reqc := make(chan *Request, 1)
5519 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5520 if r.Method != "CONNECT" {
5521 t.Errorf("method = %q; want CONNECT", r.Method)
5522 }
5523 reqc <- r
5524 c, _, err := w.(Hijacker).Hijack()
5525 if err != nil {
5526 t.Errorf("Hijack: %v", err)
5527 return
5528 }
5529 c.Close()
5530 })).ts
5531
5532 c := ts.Client()
5533 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5534 return url.Parse(ts.URL)
5535 }
5536 c.Transport.(*Transport).ProxyConnectHeader = Header{
5537 "User-Agent": {"foo"},
5538 "Other": {"bar"},
5539 }
5540
5541 res, err := c.Get("https://dummy.tld/")
5542 if err == nil {
5543 res.Body.Close()
5544 t.Errorf("unexpected success")
5545 }
5546
5547 r := <-reqc
5548 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5549 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5550 }
5551 if got, want := r.Header.Get("Other"), "bar"; got != want {
5552 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5553 }
5554 }
5555
5556 func TestTransportProxyGetConnectHeader(t *testing.T) {
5557 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5558 }
5559 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5560 reqc := make(chan *Request, 1)
5561 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5562 if r.Method != "CONNECT" {
5563 t.Errorf("method = %q; want CONNECT", r.Method)
5564 }
5565 reqc <- r
5566 c, _, err := w.(Hijacker).Hijack()
5567 if err != nil {
5568 t.Errorf("Hijack: %v", err)
5569 return
5570 }
5571 c.Close()
5572 })).ts
5573
5574 c := ts.Client()
5575 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5576 return url.Parse(ts.URL)
5577 }
5578
5579 c.Transport.(*Transport).ProxyConnectHeader = Header{
5580 "User-Agent": {"foo"},
5581 "Other": {"bar"},
5582 }
5583 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5584 return Header{
5585 "User-Agent": {"foo2"},
5586 "Other": {"bar2"},
5587 }, nil
5588 }
5589
5590 res, err := c.Get("https://dummy.tld/")
5591 if err == nil {
5592 res.Body.Close()
5593 t.Errorf("unexpected success")
5594 }
5595
5596 r := <-reqc
5597 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5598 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5599 }
5600 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5601 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5602 }
5603 }
5604
// errFakeRoundTrip is the error returned by every funcRoundTripper call.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper has a RoundTrip method that runs the function for its side
// effect, then always fails with errFakeRoundTrip and no response.
type funcRoundTripper func()

// RoundTrip ignores the request, calls f, and returns (nil, errFakeRoundTrip).
func (f funcRoundTripper) RoundTrip(_ *Request) (res *Response, err error) {
	f()
	err = errFakeRoundTrip
	return res, err
}
5613
// wantBody checks the outcome of an HTTP request. If err is non-nil it is
// returned unchanged. Otherwise res.Body is read in full, and an error is
// returned if reading fails, if the body is not exactly want, or if closing
// the body fails. When err is nil, res.Body is always closed.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	// Close unconditionally so a failed check does not leak the body (and
	// its connection). The close error is reported only when the read and
	// the comparison both succeed, keeping the original error precedence.
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %v", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %v", closeErr)
	}
	return nil
}
5630
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 (::1). If neither
// works, it fails the test with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
		t.Fatal(err)
	}
	return ln
}
5641
// countCloseReader pairs an io.Reader with a Close that counts its calls in
// *n and always succeeds.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close increments the shared counter and reports success.
func (cr countCloseReader) Close() error {
	*cr.n += 1
	return nil
}
5651
5652
// rgz is a gzip stream: magic 0x1f 0x8b, CM=8 (deflate), FLG=0x08 (FNAME),
// with the embedded file name "recursive".
// NOTE(review): the file name, and callers that read exactly len(rgz) bytes
// back from a gzip-served body, suggest this is a gzip quine (it decompresses
// to itself). Confirm before relying on that property.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5687
5688
5689
5690 func TestMissingStatusNoPanic(t *testing.T) {
5691 t.Parallel()
5692
5693 const want = "unknown status code"
5694
5695 ln := newLocalListener(t)
5696 addr := ln.Addr().String()
5697 done := make(chan bool)
5698 fullAddrURL := fmt.Sprintf("http://%s", addr)
5699 raw := "HTTP/1.1 400\r\n" +
5700 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5701 "Content-Type: text/html; charset=utf-8\r\n" +
5702 "Content-Length: 10\r\n" +
5703 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5704 "Vary: Accept-Encoding\r\n\r\n" +
5705 "Aloha Olaa"
5706
5707 go func() {
5708 defer close(done)
5709
5710 conn, _ := ln.Accept()
5711 if conn != nil {
5712 io.WriteString(conn, raw)
5713 io.ReadAll(conn)
5714 conn.Close()
5715 }
5716 }()
5717
5718 proxyURL, err := url.Parse(fullAddrURL)
5719 if err != nil {
5720 t.Fatalf("proxyURL: %v", err)
5721 }
5722
5723 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5724
5725 req, _ := NewRequest("GET", "https://golang.org/", nil)
5726 res, err, panicked := doFetchCheckPanic(tr, req)
5727 if panicked {
5728 t.Error("panicked, expecting an error")
5729 }
5730 if res != nil && res.Body != nil {
5731 io.Copy(io.Discard, res.Body)
5732 res.Body.Close()
5733 }
5734
5735 if err == nil || !strings.Contains(err.Error(), want) {
5736 t.Errorf("got=%v want=%q", err, want)
5737 }
5738
5739 ln.Close()
5740 <-done
5741 }
5742
// doFetchCheckPanic calls tr.RoundTrip(req) and reports whether it panicked.
// A panic is recovered and its value discarded; in that case res and err
// keep their zero values and panicked is true.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5752
5753
5754
5755 func TestNoBodyOnChunked304Response(t *testing.T) {
5756 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
5757 }
5758 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
5759 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5760 conn, buf, _ := w.(Hijacker).Hijack()
5761 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
5762 buf.Flush()
5763 conn.Close()
5764 }))
5765
5766
5767
5768
5769
5770 cst.tr.DisableKeepAlives = true
5771
5772 res, err := cst.c.Get(cst.ts.URL)
5773 if err != nil {
5774 t.Fatal(err)
5775 }
5776
5777 if res.Body != NoBody {
5778 t.Errorf("Unexpected body on 304 response")
5779 }
5780 }
5781
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write passes p to f and returns f's results unchanged.
func (f funcWriter) Write(p []byte) (n int, err error) {
	n, err = f(p)
	return
}
5785
// doneContext wraps a Context but always reports itself as done: Done
// returns an already-closed channel and Err returns the stored err.
// Deadline and Value come from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a new channel that is already closed.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	defer close(closed)
	return closed
}

// Err returns the error supplied when the doneContext was built.
func (d doneContext) Err() error {
	return d.err
}
5798
5799
5800 func TestTransportCheckContextDoneEarly(t *testing.T) {
5801 tr := &Transport{}
5802 req, _ := NewRequest("GET", "http://fake.example/", nil)
5803 wantErr := errors.New("some error")
5804 req = req.WithContext(doneContext{context.Background(), wantErr})
5805 _, err := tr.RoundTrip(req)
5806 if err != wantErr {
5807 t.Errorf("error = %v; want %v", err, wantErr)
5808 }
5809 }
5810
5811
5812
5813
5814
5815
// TestClientTimeoutKillsConn_BeforeHeaders checks what happens when
// Client.Timeout expires before any response headers arrive. The server
// handler must see its request context canceled, and a read on the hijacked
// connection must then return io.EOF (presumably because the client closed
// its end). The timeout starts at 1ms and is doubled whenever the handler is
// not reached in time.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			<-r.Context().Done()

			// Either the test goroutine has abandoned this attempt
			// (cancelHandler closed), or it is waiting on inHandler.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should have closed its side of the connection, so a
			// read on the hijacked conn is expected to hit io.EOF.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler did not report within 10x the timeout. Presumably
			// the timeout was too short for the request to reach the server,
			// so abandon this attempt and retry with a doubled timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
5874
5875
5876
5877
5878
5879
// TestClientTimeoutKillsConn_AfterHeaders covers a request canceled after
// response headers (Content-Length: 100) arrived but before any body bytes.
// Reading the body must fail, and the server's subsequent read on the
// hijacked connection must fail too. Together these show the client dropped
// the connection instead of keeping it.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Block until the test goroutine has finished reading the body (it
		// receives on inHandler only afterwards), or until it gives up.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// The client is expected to have closed the connection after the
		// cancellation, so the read must fail. Any error (io.EOF, a reset,
		// ...) is accepted; only a successful read is a failure.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// NOTE(review): Client.Timeout is set to a value that never fires here,
	// presumably so the request goes through the Client's timeout handling.
	// The request is actually aborted explicitly via req.Cancel below.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Headers have arrived, but the handler is still blocked before writing
	// any body bytes. Cancel the request; reading the body must then fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Let the handler proceed with its checks and wait for it to finish.
	<-inHandler
	<-handlerDone
}
5945
5946 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
5947 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
5948 }
5949 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
5950 done := make(chan struct{})
5951 defer close(done)
5952 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5953 conn, _, err := w.(Hijacker).Hijack()
5954 if err != nil {
5955 t.Error(err)
5956 return
5957 }
5958 defer conn.Close()
5959 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
5960 bs := bufio.NewScanner(conn)
5961 bs.Scan()
5962 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
5963 <-done
5964 }))
5965
5966 req, _ := NewRequest("GET", cst.ts.URL, nil)
5967 req.Header.Set("Upgrade", "foo")
5968 req.Header.Set("Connection", "upgrade")
5969 res, err := cst.c.Do(req)
5970 if err != nil {
5971 t.Fatal(err)
5972 }
5973 if res.StatusCode != 101 {
5974 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
5975 }
5976 rwc, ok := res.Body.(io.ReadWriteCloser)
5977 if !ok {
5978 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
5979 }
5980 defer rwc.Close()
5981 bs := bufio.NewScanner(rwc)
5982 if !bs.Scan() {
5983 t.Fatalf("expected readable input")
5984 }
5985 if got, want := bs.Text(), "Some buffered data"; got != want {
5986 t.Errorf("read %q; want %q", got, want)
5987 }
5988 io.WriteString(rwc, "echo\n")
5989 if !bs.Scan() {
5990 t.Fatalf("expected another line")
5991 }
5992 if got, want := bs.Text(), "ECHO"; got != want {
5993 t.Errorf("read %q; want %q", got, want)
5994 }
5995 }
5996
5997 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
5998 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
5999 const target = "backend:443"
6000 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6001 if r.Method != "CONNECT" {
6002 t.Errorf("unexpected method %q", r.Method)
6003 w.WriteHeader(500)
6004 return
6005 }
6006 if r.RequestURI != target {
6007 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6008 w.WriteHeader(500)
6009 return
6010 }
6011 nc, brw, err := w.(Hijacker).Hijack()
6012 if err != nil {
6013 t.Error(err)
6014 return
6015 }
6016 defer nc.Close()
6017 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6018
6019 for {
6020 line, err := brw.ReadString('\n')
6021 if err != nil {
6022 if err != io.EOF {
6023 t.Error(err)
6024 }
6025 return
6026 }
6027 io.WriteString(brw, strings.ToUpper(line))
6028 brw.Flush()
6029 }
6030 }))
6031 pr, pw := io.Pipe()
6032 defer pw.Close()
6033 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6034 if err != nil {
6035 t.Fatal(err)
6036 }
6037 req.URL.Opaque = target
6038 res, err := cst.c.Do(req)
6039 if err != nil {
6040 t.Fatal(err)
6041 }
6042 defer res.Body.Close()
6043 if res.StatusCode != 200 {
6044 t.Fatalf("status code = %d; want 200", res.StatusCode)
6045 }
6046 br := bufio.NewReader(res.Body)
6047 for _, str := range []string{"foo", "bar", "baz"} {
6048 fmt.Fprintf(pw, "%s\n", str)
6049 got, err := br.ReadString('\n')
6050 if err != nil {
6051 t.Fatal(err)
6052 }
6053 got = strings.TrimSpace(got)
6054 want := strings.ToUpper(str)
6055 if got != want {
6056 t.Fatalf("got %q; want %q", got, want)
6057 }
6058 }
6059 }
6060
6061 func TestTransportRequestReplayable(t *testing.T) {
6062 someBody := io.NopCloser(strings.NewReader(""))
6063 tests := []struct {
6064 name string
6065 req *Request
6066 want bool
6067 }{
6068 {
6069 name: "GET",
6070 req: &Request{Method: "GET"},
6071 want: true,
6072 },
6073 {
6074 name: "GET_http.NoBody",
6075 req: &Request{Method: "GET", Body: NoBody},
6076 want: true,
6077 },
6078 {
6079 name: "GET_body",
6080 req: &Request{Method: "GET", Body: someBody},
6081 want: false,
6082 },
6083 {
6084 name: "POST",
6085 req: &Request{Method: "POST"},
6086 want: false,
6087 },
6088 {
6089 name: "POST_idempotency-key",
6090 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6091 want: true,
6092 },
6093 {
6094 name: "POST_x-idempotency-key",
6095 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6096 want: true,
6097 },
6098 {
6099 name: "POST_body",
6100 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6101 want: false,
6102 },
6103 }
6104 for _, tt := range tests {
6105 t.Run(tt.name, func(t *testing.T) {
6106 got := tt.req.ExportIsReplayable()
6107 if got != tt.want {
6108 t.Errorf("replyable = %v; want %v", got, tt.want)
6109 }
6110 })
6111 }
6112 }
6113
6114
6115
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom method
// was invoked. Tests use it to see whether the transport took the
// io.ReaderFrom path when writing a request body.
type testMockTCPConn struct {
	*net.TCPConn

	// ReadFromCalled becomes true once ReadFrom has been called.
	ReadFromCalled bool
}

// ReadFrom records the call and then delegates to the embedded TCP connection.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	n, err := m.TCPConn.ReadFrom(src)
	return n, err
}
6126
6127 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6128 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6129 nBytes := int64(1 << 10)
6130 newFileFunc := func() (r io.Reader, done func(), err error) {
6131 f, err := os.CreateTemp("", "net-http-newfilefunc")
6132 if err != nil {
6133 return nil, nil, err
6134 }
6135
6136
6137 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6138 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6139 }
6140 if _, err := f.Seek(0, 0); err != nil {
6141 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6142 }
6143
6144 done = func() {
6145 f.Close()
6146 os.Remove(f.Name())
6147 }
6148
6149 return f, done, nil
6150 }
6151
6152 newBufferFunc := func() (io.Reader, func(), error) {
6153 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6154 }
6155
6156 cases := []struct {
6157 name string
6158 readerFunc func() (io.Reader, func(), error)
6159 contentLength int64
6160 expectedReadFrom bool
6161 }{
6162 {
6163 name: "file, length",
6164 readerFunc: newFileFunc,
6165 contentLength: nBytes,
6166 expectedReadFrom: true,
6167 },
6168 {
6169 name: "file, no length",
6170 readerFunc: newFileFunc,
6171 },
6172 {
6173 name: "file, negative length",
6174 readerFunc: newFileFunc,
6175 contentLength: -1,
6176 },
6177 {
6178 name: "buffer",
6179 contentLength: nBytes,
6180 readerFunc: newBufferFunc,
6181 },
6182 {
6183 name: "buffer, no length",
6184 readerFunc: newBufferFunc,
6185 },
6186 {
6187 name: "buffer, length -1",
6188 contentLength: -1,
6189 readerFunc: newBufferFunc,
6190 },
6191 }
6192
6193 for _, tc := range cases {
6194 t.Run(tc.name, func(t *testing.T) {
6195 r, cleanup, err := tc.readerFunc()
6196 if err != nil {
6197 t.Fatal(err)
6198 }
6199 defer cleanup()
6200
6201 tConn := &testMockTCPConn{}
6202 trFunc := func(tr *Transport) {
6203 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6204 var d net.Dialer
6205 conn, err := d.DialContext(ctx, network, addr)
6206 if err != nil {
6207 return nil, err
6208 }
6209
6210 tcpConn, ok := conn.(*net.TCPConn)
6211 if !ok {
6212 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6213 }
6214
6215 tConn.TCPConn = tcpConn
6216 return tConn, nil
6217 }
6218 }
6219
6220 cst := newClientServerTest(
6221 t,
6222 mode,
6223 HandlerFunc(func(w ResponseWriter, r *Request) {
6224 io.Copy(io.Discard, r.Body)
6225 r.Body.Close()
6226 w.WriteHeader(200)
6227 }),
6228 trFunc,
6229 )
6230
6231 req, err := NewRequest("PUT", cst.ts.URL, r)
6232 if err != nil {
6233 t.Fatal(err)
6234 }
6235 req.ContentLength = tc.contentLength
6236 req.Header.Set("Content-Type", "application/octet-stream")
6237 resp, err := cst.c.Do(req)
6238 if err != nil {
6239 t.Fatal(err)
6240 }
6241 defer resp.Body.Close()
6242 if resp.StatusCode != 200 {
6243 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6244 }
6245
6246 expectedReadFrom := tc.expectedReadFrom
6247 if mode != http1Mode {
6248 expectedReadFrom = false
6249 }
6250 if !tConn.ReadFromCalled && expectedReadFrom {
6251 t.Fatalf("did not call ReadFrom")
6252 }
6253
6254 if tConn.ReadFromCalled && !expectedReadFrom {
6255 t.Fatalf("ReadFrom was unexpectedly invoked")
6256 }
6257 })
6258 }
6259 }
6260
// TestTransportClone verifies that Transport.Clone copies every exported
// field. Each field is set to a non-zero value on the original, and the test
// fails if any exported field of the clone is zero. It also checks that a
// TLSNextProto map is carried over and that a nil TLSNextProto stays nil.
func TestTransportClone(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	tr2 := tr.Clone()
	rv := reflect.ValueOf(tr2).Elem()
	rt := rv.Type()
	for i := 0; i < rt.NumField(); i++ {
		sf := rt.Field(i)
		if !token.IsExported(sf.Name) {
			continue
		}
		if rv.Field(i).IsZero() {
			t.Errorf("cloned field tr2.%s is zero", sf.Name)
		}
	}

	if _, ok := tr2.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A nil TLSNextProto must stay nil in the clone rather than become an
	// empty map.
	tr = new(Transport)
	tr2 = tr.Clone()
	if tr2.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6315
6316 func TestIs408(t *testing.T) {
6317 tests := []struct {
6318 in string
6319 want bool
6320 }{
6321 {"HTTP/1.0 408", true},
6322 {"HTTP/1.1 408", true},
6323 {"HTTP/1.8 408", true},
6324 {"HTTP/2.0 408", false},
6325 {"HTTP/1.1 408 ", true},
6326 {"HTTP/1.1 40", false},
6327 {"http/1.0 408", false},
6328 {"HTTP/1-1 408", false},
6329 }
6330 for _, tt := range tests {
6331 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6332 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6333 }
6334 }
6335 }
6336
6337 func TestTransportIgnores408(t *testing.T) {
6338 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6339 }
6340 func testTransportIgnores408(t *testing.T, mode testMode) {
6341
6342 defer log.SetOutput(log.Writer())
6343
6344 var logout strings.Builder
6345 log.SetOutput(&logout)
6346
6347 const target = "backend:443"
6348
6349 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6350 nc, _, err := w.(Hijacker).Hijack()
6351 if err != nil {
6352 t.Error(err)
6353 return
6354 }
6355 defer nc.Close()
6356 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6357 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6358 }))
6359 req, err := NewRequest("GET", cst.ts.URL, nil)
6360 if err != nil {
6361 t.Fatal(err)
6362 }
6363 res, err := cst.c.Do(req)
6364 if err != nil {
6365 t.Fatal(err)
6366 }
6367 slurp, err := io.ReadAll(res.Body)
6368 if err != nil {
6369 t.Fatal(err)
6370 }
6371 if err != nil {
6372 t.Fatal(err)
6373 }
6374 if string(slurp) != "ok" {
6375 t.Fatalf("got %q; want ok", slurp)
6376 }
6377
6378 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6379 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6380 if d > 0 {
6381 t.Logf("%v idle conns still present after %v", n, d)
6382 }
6383 return false
6384 }
6385 return true
6386 })
6387 if got := logout.String(); got != "" {
6388 t.Fatalf("expected no log output; got: %s", got)
6389 }
6390 }
6391
6392 func TestInvalidHeaderResponse(t *testing.T) {
6393 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6394 }
6395 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6396 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6397 conn, buf, _ := w.(Hijacker).Hijack()
6398 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6399 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6400 "Content-Type: text/html; charset=utf-8\r\n" +
6401 "Content-Length: 0\r\n" +
6402 "Foo : bar\r\n\r\n"))
6403 buf.Flush()
6404 conn.Close()
6405 }))
6406 res, err := cst.c.Get(cst.ts.URL)
6407 if err != nil {
6408 t.Fatal(err)
6409 }
6410 defer res.Body.Close()
6411 if v := res.Header.Get("Foo"); v != "" {
6412 t.Errorf(`unexpected "Foo" header: %q`, v)
6413 }
6414 if v := res.Header.Get("Foo "); v != "bar" {
6415 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6416 }
6417 }
6418
// bodyCloser is an always-empty request body that records whether it has
// been closed.
type bodyCloser bool

// Close marks the body as closed. It never fails.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read reports end of input straight away, because the body has no content.
func (b *bodyCloser) Read(p []byte) (int, error) {
	return 0, io.EOF
}
6428
6429
6430
6431 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6432 run(t, testTransportClosesBodyOnInvalidRequests)
6433 }
6434 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6435 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6436 t.Errorf("Should not have been invoked")
6437 })).ts
6438
6439 u, _ := url.Parse(cst.URL)
6440
6441 tests := []struct {
6442 name string
6443 req *Request
6444 wantErr string
6445 }{
6446 {
6447 name: "invalid method",
6448 req: &Request{
6449 Method: " ",
6450 URL: u,
6451 },
6452 wantErr: `invalid method " "`,
6453 },
6454 {
6455 name: "nil URL",
6456 req: &Request{
6457 Method: "GET",
6458 },
6459 wantErr: `nil Request.URL`,
6460 },
6461 {
6462 name: "invalid header key",
6463 req: &Request{
6464 Method: "GET",
6465 Header: Header{"💡": {"emoji"}},
6466 URL: u,
6467 },
6468 wantErr: `invalid header field name "💡"`,
6469 },
6470 {
6471 name: "invalid header value",
6472 req: &Request{
6473 Method: "POST",
6474 Header: Header{"key": {"\x19"}},
6475 URL: u,
6476 },
6477 wantErr: `invalid header field value for "key"`,
6478 },
6479 {
6480 name: "non HTTP(s) scheme",
6481 req: &Request{
6482 Method: "POST",
6483 URL: &url.URL{Scheme: "faux"},
6484 },
6485 wantErr: `unsupported protocol scheme "faux"`,
6486 },
6487 {
6488 name: "no Host in URL",
6489 req: &Request{
6490 Method: "POST",
6491 URL: &url.URL{Scheme: "http"},
6492 },
6493 wantErr: `no Host in request URL`,
6494 },
6495 }
6496
6497 for _, tt := range tests {
6498 t.Run(tt.name, func(t *testing.T) {
6499 var bc bodyCloser
6500 req := tt.req
6501 req.Body = &bc
6502 _, err := cst.Client().Do(tt.req)
6503 if err == nil {
6504 t.Fatal("Expected an error")
6505 }
6506 if !bc {
6507 t.Fatal("Expected body to have been closed")
6508 }
6509 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6510 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6511 }
6512 })
6513 }
6514 }
6515
6516
6517
// breakableConn is a net.Conn whose writes can be made to fail on demand. It
// uses a brokenState that several connections may share.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag. While broken is true, Write fails on
// every breakableConn that shares this state.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with a fixed error while the shared state is broken, and
// otherwise forwards to the underlying connection. The state's lock is held
// for the whole write.
func (c *breakableConn) Write(p []byte) (int, error) {
	c.Lock()
	defer c.Unlock()
	if !c.broken {
		return c.Conn.Write(p)
	}
	return 0, errors.New("some write error")
}
6536
6537
6538 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6539 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6540 }
6541 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6542 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6543
6544 var brokenState brokenState
6545
6546 const numReqs = 5
6547 var numDials, gotConns uint32
6548
6549 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6550 atomic.AddUint32(&numDials, 1)
6551 c, err := net.Dial(netw, addr)
6552 if err != nil {
6553 t.Errorf("unexpected Dial error: %v", err)
6554 return nil, err
6555 }
6556 return &breakableConn{c, &brokenState}, err
6557 }
6558
6559 for i := 1; i <= numReqs; i++ {
6560 brokenState.Lock()
6561 brokenState.broken = false
6562 brokenState.Unlock()
6563
6564
6565
6566
6567 doBreak := i != numReqs
6568
6569 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6570 GotConn: func(info httptrace.GotConnInfo) {
6571 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6572 atomic.AddUint32(&gotConns, 1)
6573 },
6574 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6575 brokenState.Lock()
6576 defer brokenState.Unlock()
6577 if doBreak {
6578 brokenState.broken = true
6579 }
6580 },
6581 })
6582 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6583 if err != nil {
6584 t.Fatal(err)
6585 }
6586 _, err = cst.c.Do(req)
6587 if doBreak != (err != nil) {
6588 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6589 }
6590 }
6591 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6592 t.Errorf("GotConn calls = %v; want %v", got, want)
6593 }
6594 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6595 t.Errorf("Dials = %v; want %v", got, want)
6596 }
6597 }
6598
6599
6600
6601
6602
6603 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6604 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6605 }
6606 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6607 CondSkipHTTP2(t)
6608
6609 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6610 _, err := w.Write([]byte("foo"))
6611 if err != nil {
6612 t.Fatalf("Write: %v", err)
6613 }
6614 })
6615
6616 ts := newClientServerTest(t, mode, h).ts
6617
6618 c := ts.Client()
6619 tr := c.Transport.(*Transport)
6620 tr.MaxConnsPerHost = 1
6621
6622 errCh := make(chan error, 300)
6623 doReq := func() {
6624 resp, err := c.Get(ts.URL)
6625 if err != nil {
6626 errCh <- fmt.Errorf("request failed: %v", err)
6627 return
6628 }
6629 defer resp.Body.Close()
6630 _, err = io.ReadAll(resp.Body)
6631 if err != nil {
6632 errCh <- fmt.Errorf("read body failed: %v", err)
6633 }
6634 }
6635
6636 var wg sync.WaitGroup
6637 for i := 0; i < 300; i++ {
6638 wg.Add(1)
6639 go func() {
6640 defer wg.Done()
6641 doReq()
6642 }()
6643 }
6644 wg.Wait()
6645 close(errCh)
6646
6647 for err := range errCh {
6648 t.Errorf("error occurred: %v", err)
6649 }
6650 }
6651
6652
6653
6654
6655 func TestAltProtoCancellation(t *testing.T) {
6656 defer afterTest(t)
6657 tr := &Transport{}
6658 c := &Client{
6659 Transport: tr,
6660 Timeout: time.Millisecond,
6661 }
6662 tr.RegisterProtocol("cancel", cancelProto{})
6663 _, err := c.Get("cancel://bar.com/path")
6664 if err == nil {
6665 t.Error("request unexpectedly succeeded")
6666 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6667 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6668 }
6669 }
6670
// errCancelProto is the error cancelProto returns once its request is canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request on its own.
type cancelProto struct{}

// RoundTrip blocks until a receive from r.Cancel succeeds, for example
// because the channel was closed, and then fails with errCancelProto.
func (cancelProto) RoundTrip(r *Request) (*Response, error) {
	select {
	case <-r.Cancel:
	}
	return nil, errCancelProto
}
6679
// roundTripFunc adapts a plain function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip forwards req to the wrapped function and returns its results.
func (f roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	res, err := f(req)
	return res, err
}
6683
6684
6685 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6686 func testIssue32441(t *testing.T, mode testMode) {
6687 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6688 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6689 t.Error("body length is zero")
6690 }
6691 })).ts
6692 c := ts.Client()
6693 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6694
6695 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6696 t.Error("body length is zero during round trip")
6697 }
6698 return nil, ErrSkipAltProtocol
6699 }))
6700 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6701 t.Error(err)
6702 }
6703 }
6704
6705
6706
6707 func TestTransportRejectsSignInContentLength(t *testing.T) {
6708 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6709 }
6710 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6711 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6712 w.Header().Set("Content-Length", "+3")
6713 w.Write([]byte("abc"))
6714 })).ts
6715
6716 c := cst.Client()
6717 res, err := c.Get(cst.URL)
6718 if err == nil || res != nil {
6719 t.Fatal("Expected a non-nil error and a nil http.Response")
6720 }
6721 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6722 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6723 }
6724 }
6725
6726
// dumpConn is a net.Conn whose writes go to an io.Writer and whose reads come
// from an io.Reader. Close, the deadline setters and the address getters do
// nothing: they return nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                     { return nil }
func (*dumpConn) LocalAddr() net.Addr              { return nil }
func (*dumpConn) RemoteAddr() net.Addr             { return nil }
func (*dumpConn) SetDeadline(time.Time) error      { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error  { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
6738
6739
6740
// delegateReader is an io.Reader whose real source arrives later. The first
// Read blocks until a reader is received on c, and from then on every Read is
// forwarded to that reader. If c is closed before a reader arrives, Read
// returns an error.
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r != nil {
		return d.r.Read(p)
	}
	src, ok := <-d.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	d.r = src
	return d.r.Read(p)
}
6755
6756 func testTransportRace(req *Request) {
6757 save := req.Body
6758 pr, pw := io.Pipe()
6759 defer pr.Close()
6760 defer pw.Close()
6761 dr := &delegateReader{c: make(chan io.Reader)}
6762
6763 t := &Transport{
6764 Dial: func(net, addr string) (net.Conn, error) {
6765 return &dumpConn{pw, dr}, nil
6766 },
6767 }
6768 defer t.CloseIdleConnections()
6769
6770 quitReadCh := make(chan struct{})
6771
6772 go func() {
6773 defer close(quitReadCh)
6774
6775 req, err := ReadRequest(bufio.NewReader(pr))
6776 if err == nil {
6777
6778
6779 io.Copy(io.Discard, req.Body)
6780 req.Body.Close()
6781 }
6782 select {
6783 case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
6784 case quitReadCh <- struct{}{}:
6785
6786 close(dr.c)
6787 }
6788 }()
6789
6790 t.RoundTrip(req)
6791
6792
6793
6794 pw.Close()
6795 <-quitReadCh
6796
6797 req.Body = save
6798 }
6799
6800
6801
6802
6803
6804 func TestErrorWriteLoopRace(t *testing.T) {
6805 if testing.Short() {
6806 return
6807 }
6808 t.Parallel()
6809 for i := 0; i < 1000; i++ {
6810 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6811 ctx, cancel := context.WithTimeout(context.Background(), delay)
6812 defer cancel()
6813
6814 r := bytes.NewBuffer(make([]byte, 10000))
6815 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6816 if err != nil {
6817 t.Fatal(err)
6818 }
6819
6820 testTransportRace(req)
6821 }
6822 }
6823
6824
6825
6826
6827 func TestCancelRequestWhenSharingConnection(t *testing.T) {
6828 run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
6829 }
6830 func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
6831 reqc := make(chan chan struct{}, 2)
6832 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
6833 ch := make(chan struct{}, 1)
6834 reqc <- ch
6835 <-ch
6836 w.Header().Add("Content-Length", "0")
6837 })).ts
6838
6839 client := ts.Client()
6840 transport := client.Transport.(*Transport)
6841 transport.MaxIdleConns = 1
6842 transport.MaxConnsPerHost = 1
6843
6844 var wg sync.WaitGroup
6845
6846 wg.Add(1)
6847 putidlec := make(chan chan struct{}, 1)
6848 reqerrc := make(chan error, 1)
6849 go func() {
6850 defer wg.Done()
6851 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6852 PutIdleConn: func(error) {
6853
6854
6855 ch := make(chan struct{})
6856 putidlec <- ch
6857 close(putidlec)
6858 <-ch
6859 },
6860 })
6861 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
6862 res, err := client.Do(req)
6863 reqerrc <- err
6864 if err == nil {
6865 res.Body.Close()
6866 }
6867 }()
6868
6869
6870
6871 r1c := <-reqc
6872 close(r1c)
6873 var idlec chan struct{}
6874 select {
6875 case err := <-reqerrc:
6876 if err != nil {
6877 t.Fatalf("request 1: got err %v, want nil", err)
6878 }
6879 idlec = <-putidlec
6880 case idlec = <-putidlec:
6881 }
6882
6883 wg.Add(1)
6884 cancelctx, cancel := context.WithCancel(context.Background())
6885 go func() {
6886 defer wg.Done()
6887 req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
6888 res, err := client.Do(req)
6889 if err == nil {
6890 res.Body.Close()
6891 }
6892 if !errors.Is(err, context.Canceled) {
6893 t.Errorf("request 2: got err %v, want Canceled", err)
6894 }
6895
6896
6897 close(idlec)
6898 }()
6899
6900
6901
6902 r2c := <-reqc
6903 cancel()
6904
6905 <-idlec
6906
6907 close(r2c)
6908 wg.Wait()
6909 }
6910
6911 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
6912 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
6913 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6914 go io.Copy(io.Discard, req.Body)
6915 panic(ErrAbortHandler)
6916 })).ts
6917
6918 var wg sync.WaitGroup
6919 for i := 0; i < 2; i++ {
6920 wg.Add(1)
6921 go func() {
6922 defer wg.Done()
6923 for j := 0; j < 10; j++ {
6924 const reqLen = 6 * 1024 * 1024
6925 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6926 req.ContentLength = reqLen
6927 resp, _ := ts.Client().Transport.RoundTrip(req)
6928 if resp != nil {
6929 resp.Body.Close()
6930 }
6931 }
6932 }()
6933 }
6934 wg.Wait()
6935 }
6936
6937 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
6938 func testRequestSanitization(t *testing.T, mode testMode) {
6939 if mode == http2Mode {
6940
6941 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
6942 }
6943 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6944 if h, ok := req.Header["X-Evil"]; ok {
6945 t.Errorf("request has X-Evil header: %q", h)
6946 }
6947 })).ts
6948 req, _ := NewRequest("GET", ts.URL, nil)
6949 req.Host = "go.dev\r\nX-Evil:evil"
6950 resp, _ := ts.Client().Do(req)
6951 if resp != nil {
6952 resp.Body.Close()
6953 }
6954 }
6955
6956 func TestProxyAuthHeader(t *testing.T) {
6957
6958 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
6959 }
6960 func testProxyAuthHeader(t *testing.T, mode testMode) {
6961 const username = "u"
6962 const password = "@/?!"
6963 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6964
6965
6966 var r2 Request
6967 r2.Header = Header{
6968 "Authorization": req.Header["Proxy-Authorization"],
6969 }
6970 gotuser, gotpass, ok := r2.BasicAuth()
6971 if !ok || gotuser != username || gotpass != password {
6972 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
6973 }
6974 }))
6975 u, err := url.Parse(cst.ts.URL)
6976 if err != nil {
6977 t.Fatal(err)
6978 }
6979 u.User = url.UserPassword(username, password)
6980 t.Setenv("HTTP_PROXY", u.String())
6981 cst.tr.Proxy = ProxyURL(u)
6982 resp, err := cst.c.Get("http://_/")
6983 if err != nil {
6984 t.Fatal(err)
6985 }
6986 resp.Body.Close()
6987 }
6988
View as plain text