1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "internal/godebug"
14 "io"
15 "log"
16 "mime"
17 "net"
18 "net/http"
19 "net/http/httptrace"
20 "net/http/internal/ascii"
21 "net/textproto"
22 "net/url"
23 "strings"
24 "sync"
25 "sync/atomic"
26 "time"
27
28 "golang.org/x/net/http/httpguts"
29 )
30
31
// ProxyRequest is the value passed to a ReverseProxy's Rewrite hook.
type ProxyRequest struct {
	// In is the request received by the proxy, exactly as it was passed
	// to ServeHTTP.
	In *http.Request

	// Out is the request that will be sent to the backend. It starts as a
	// clone of In that ServeHTTP has already adjusted: hop-by-hop headers
	// are removed (apart from a re-added "Te: trailers" and any Upgrade
	// request), the client's Forwarded and X-Forwarded-* headers are
	// deleted, and the query has been cleaned. Rewrite may modify Out or
	// replace it; whatever Out holds when Rewrite returns is what is sent.
	Out *http.Request
}
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58 func (r *ProxyRequest) SetURL(target *url.URL) {
59 rewriteRequestURL(r.Out, target)
60 r.Out.Host = ""
61 }
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82 func (r *ProxyRequest) SetXForwarded() {
83 clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
84 if err == nil {
85 prior := r.Out.Header["X-Forwarded-For"]
86 if len(prior) > 0 {
87 clientIP = strings.Join(prior, ", ") + ", " + clientIP
88 }
89 r.Out.Header.Set("X-Forwarded-For", clientIP)
90 } else {
91 r.Out.Header.Del("X-Forwarded-For")
92 }
93 r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
94 if r.In.TLS == nil {
95 r.Out.Header.Set("X-Forwarded-Proto", "http")
96 } else {
97 r.Out.Header.Set("X-Forwarded-Proto", "https")
98 }
99 }
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// ReverseProxy is an http.Handler that forwards each incoming request to a
// backend server and copies the backend's response back to the client.
//
// Exactly one of Rewrite or Director must be set. Otherwise ServeHTTP
// reports an error through the error handler without contacting a backend.
type ReverseProxy struct {
	// Rewrite, if set, modifies the outbound request (pr.Out), which starts
	// as a clone of the inbound request (pr.In). By the time Rewrite runs:
	// hop-by-hop headers have been removed (apart from a re-added
	// "Te: trailers" and any Upgrade request); the client's Forwarded,
	// X-Forwarded-For, X-Forwarded-Host, and X-Forwarded-Proto headers have
	// been deleted; and the raw query has been passed through
	// cleanQueryParams. Call pr.SetXForwarded to add fresh X-Forwarded-*
	// headers. Rewrite may replace pr.Out entirely.
	Rewrite func(*ProxyRequest)

	// Transport performs the request to the backend. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval bounds how long written response-body bytes may wait
	// before being flushed to the client. Zero disables periodic flushing;
	// a negative value flushes after every write. Responses with
	// Content-Type text/event-stream, or with ContentLength -1, are always
	// flushed after every write regardless of this value.
	FlushInterval time.Duration

	// ErrorLog receives the proxy's log messages. If nil, the log package's
	// standard logger is used.
	ErrorLog *log.Logger

	// BufferPool, if set, supplies the scratch buffer used to copy each
	// response body to the client.
	BufferPool BufferPool

	// ModifyResponse, if set, is called with the backend's response before
	// its headers are copied to the client. For ordinary responses it runs
	// after hop-by-hop headers are removed from the response. For
	// 101 Switching Protocols responses it runs before the upgrade is
	// handled. If it returns an error, the response body is closed and the
	// error is passed to the error handler.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if set, is called for proxy errors: a missing or
	// ambiguous Director/Rewrite, an invalid upgrade protocol, transport
	// errors, ModifyResponse errors, and protocol-switch failures. If nil,
	// the error is logged and the client receives 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)

	// Director, if set, modifies the outbound request in place. It receives
	// a clone of the inbound request. Processing that runs after Director
	// returns can override its changes:
	//   - hop-by-hop headers (including any Director added) are removed;
	//   - the client IP is appended to X-Forwarded-For, unless Director set
	//     that header to a nil value;
	//   - when the request's form has been parsed (Form is non-nil), the raw
	//     query is passed through cleanQueryParams.
	Director func(*http.Request)
}
268
269
270
// BufferPool supplies scratch buffers for copying response bodies.
//
// When a ReverseProxy has a BufferPool, Get is called once per response
// copy, and the slice it returned is handed back to Put when the copy
// finishes. If Get returns a zero-length slice, the copy allocates its own
// 32 KiB buffer, but Put still receives the slice that Get returned.
// NOTE(review): ServeHTTP may run concurrently, so implementations
// presumably need to be safe for concurrent use.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
275
// singleJoiningSlash concatenates a and b so that exactly one '/' sits at
// the seam. A doubled slash at the seam is collapsed, and a missing one is
// inserted. Slashes elsewhere are left alone. Two empty strings join to "/".
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + strings.TrimPrefix(b, "/")
	}
	if !trailing && !leading {
		return a + "/" + b
	}
	return a + b
}
287
288 func joinURLPath(a, b *url.URL) (path, rawpath string) {
289 if a.RawPath == "" && b.RawPath == "" {
290 return singleJoiningSlash(a.Path, b.Path), ""
291 }
292
293
294 apath := a.EscapedPath()
295 bpath := b.EscapedPath()
296
297 aslash := strings.HasSuffix(apath, "/")
298 bslash := strings.HasPrefix(bpath, "/")
299
300 switch {
301 case aslash && bslash:
302 return a.Path + b.Path[1:], apath + bpath[1:]
303 case !aslash && !bslash:
304 return a.Path + "/" + b.Path, apath + "/" + bpath
305 }
306 return a.Path + b.Path, apath + bpath
307 }
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
334 director := func(req *http.Request) {
335 rewriteRequestURL(req, target)
336 }
337 return &ReverseProxy{Director: director}
338 }
339
340 func rewriteRequestURL(req *http.Request, target *url.URL) {
341 targetQuery := target.RawQuery
342 req.URL.Scheme = target.Scheme
343 req.URL.Host = target.Host
344 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
345 if targetQuery == "" || req.URL.RawQuery == "" {
346 req.URL.RawQuery = targetQuery + req.URL.RawQuery
347 } else {
348 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
349 }
350 }
351
// copyHeader appends every value of every header in src to dst, keeping
// values already present in dst. Keys go through Header.Add and so are
// stored in canonical form.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
359
360
361
362
363
364
// hopHeaders lists headers that describe a single connection rather than
// the end-to-end message. removeHopByHopHeaders deletes all of them, plus
// any header named in a Connection header. ServeHTTP applies it to both the
// outbound request and the backend response.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonical header-key form of "TE"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
376
377 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
378 p.logf("http: proxy error: %v", err)
379 rw.WriteHeader(http.StatusBadGateway)
380 }
381
382 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
383 if p.ErrorHandler != nil {
384 return p.ErrorHandler
385 }
386 return p.defaultErrorHandler
387 }
388
389
390
391 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
392 if p.ModifyResponse == nil {
393 return true
394 }
395 if err := p.ModifyResponse(res); err != nil {
396 res.Body.Close()
397 p.getErrorHandler()(rw, req, err)
398 return false
399 }
400 return true
401 }
402
// ServeHTTP forwards req to the backend and streams the backend's response
// to rw.
//
// In outline, it:
//  1. clones req into outreq;
//  2. lets Director or Rewrite retarget outreq;
//  3. strips hop-by-hop headers, re-adding "Te: trailers" and the Upgrade
//     request where applicable;
//  4. performs the round trip, relaying any 1xx responses as they arrive;
//  5. either hands off to handleUpgradeResponse (for a 101) or copies the
//     headers, body, and trailers to rw.
//
// Errors that occur before the response status is written go to the error
// handler. A failure while copying the body panics with
// http.ErrAbortHandler when shouldPanicOnCopyError allows it; otherwise the
// failure is only logged.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context is already cancelable, so it is used as-is
		// for the outbound request.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// The context can never be canceled. Derive a cancelable one and
		// cancel it when the ResponseWriter reports that the client went
		// away. The goroutine exits on either event; the deferred cancel
		// guarantees one of them happens.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// The inbound body is known to be empty; send no body at all.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// Wrap the body so that a Close by the transport does not close
		// the inbound req.Body (noopCloseReader.Close only sets a flag),
		// and so that any Read after Close fails instead of reaching
		// req.Body.
		outreq.Body = &noopCloseReader{readCloser: outreq.Body}

		// Mark the wrapper closed once ServeHTTP returns, so nothing that
		// still holds outreq can keep reading the inbound body. The
		// receiver is bound now, before outreq is reassigned below.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		// Guarantee a non-nil map; headers are set on it below.
		outreq.Header = make(http.Header)
	}

	// Exactly one of the two hooks must be configured.
	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		// A non-nil Form means the query has been parsed. Re-clean the raw
		// query in that case. NOTE(review): presumably so the forwarded
		// query matches what was parsed; confirm the intent.
		if outreq.Form != nil {
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	// Drop any Close flag copied from the inbound request. NOTE(review):
	// presumably so the inbound connection's Close flag does not force the
	// backend connection closed.
	outreq.Close = false

	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// Te was removed as a hop-by-hop header above. Forward "Te: trailers"
	// when the client's Te header contained the trailers token.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Connection and Upgrade were removed above as well. Restore them when
	// the client asked for a protocol upgrade, so the backend sees that
	// request.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Drop client-supplied forwarding headers. Rewrite may add fresh
		// X-Forwarded-* headers via ProxyRequest.SetXForwarded.
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// Normalize the query before Rewrite sees it.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		outreq = pr.Out
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// Append the client IP to any existing X-Forwarded-For values.
			// A header that is present but holds a nil value (for example,
			// one set that way by Director) suppresses the header entirely.
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// Set an explicit empty User-Agent. NOTE(review): presumably this
		// stops the transport from adding its default User-Agent; confirm
		// against net/http Request.Write.
		outreq.Header.Set("User-Agent", "")
	}

	// Relay 1xx informational responses from the backend to the client as
	// they arrive. roundTripDone, guarded by roundTripMutex, stops this once
	// RoundTrip has returned; from then on rw belongs to the code below.
	var (
		roundTripMutex sync.Mutex
		roundTripDone  bool
	)
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMutex.Lock()
			defer roundTripMutex.Unlock()
			if roundTripDone {
				// RoundTrip has already returned; drop this late 1xx.
				return nil
			}
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Clear the 1xx headers so they are not repeated on the final
			// response.
			clear(h)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	roundTripMutex.Lock()
	roundTripDone = true
	roundTripMutex.Unlock()
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// 101 Switching Protocols: splice the two connections together instead
	// of copying a response body.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the trailer keys known before the body is read, so the
	// client learns which trailers to expect.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status line has already gone out, so a truncated body can
		// only be signaled by aborting the handler. When no server context
		// is present (shouldPanicOnCopyError is false), just log.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Close explicitly, rather than with defer, before res.Trailer is
	// inspected. The comparison below implies that trailer keys can appear
	// once the body has been consumed.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Flush before setting trailers. NOTE(review): presumably this
		// forces a chunked response so that net/http does not compute a
		// Content-Length for a short body; confirm.
		http.NewResponseController(rw).Flush()
	}

	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Some trailers were not announced up front. Send them all using the
	// http.TrailerPrefix convention.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
636
// inOurTests forces shouldPanicOnCopyError to return true. Nothing in this
// file sets it. NOTE(review): presumably this package's own tests do.
var inOurTests bool

// shouldPanicOnCopyError reports whether ServeHTTP should abort with
// panic(http.ErrAbortHandler) after an error while copying the response
// body, rather than only logging it.
//
// It returns true when inOurTests is set, or when req's context carries
// http.ServerContextKey, meaning the request is being served by an
// http.Server. It returns false otherwise, for example when ServeHTTP is
// called directly with a recorder, so the error is logged instead.
func shouldPanicOnCopyError(req *http.Request) bool {
	return inOurTests || req.Context().Value(http.ServerContextKey) != nil
}
658
659
660 func removeHopByHopHeaders(h http.Header) {
661
662 for _, f := range h["Connection"] {
663 for sf := range strings.SplitSeq(f, ",") {
664 if sf = textproto.TrimString(sf); sf != "" {
665 h.Del(sf)
666 }
667 }
668 }
669
670
671
672 for _, f := range hopHeaders {
673 h.Del(f)
674 }
675 }
676
677
678
679 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
680 resCT := res.Header.Get("Content-Type")
681
682
683
684 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
685 return -1
686 }
687
688
689 if res.ContentLength == -1 {
690 return -1
691 }
692
693 return p.FlushInterval
694 }
695
696 func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
697 var w io.Writer = dst
698
699 if flushInterval != 0 {
700 mlw := &maxLatencyWriter{
701 dst: dst,
702 flush: http.NewResponseController(dst).Flush,
703 latency: flushInterval,
704 }
705 defer mlw.stop()
706
707
708 mlw.flushPending = true
709 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
710
711 w = mlw
712 }
713
714 var buf []byte
715 if p.BufferPool != nil {
716 buf = p.BufferPool.Get()
717 defer p.BufferPool.Put(buf)
718 }
719 _, err := p.copyBuffer(w, src, buf)
720 return err
721 }
722
723
724
725 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
726 if len(buf) == 0 {
727 buf = make([]byte, 32*1024)
728 }
729 var written int64
730 for {
731 nr, rerr := src.Read(buf)
732 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
733 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
734 }
735 if nr > 0 {
736 nw, werr := dst.Write(buf[:nr])
737 if nw > 0 {
738 written += int64(nw)
739 }
740 if werr != nil {
741 return written, werr
742 }
743 if nr != nw {
744 return written, io.ErrShortWrite
745 }
746 }
747 if rerr != nil {
748 if rerr == io.EOF {
749 rerr = nil
750 }
751 return written, rerr
752 }
753 }
754 }
755
756 func (p *ReverseProxy) logf(format string, args ...any) {
757 if p.ErrorLog != nil {
758 p.ErrorLog.Printf(format, args...)
759 } else {
760 log.Printf(format, args...)
761 }
762 }
763
// maxLatencyWriter is the io.Writer that copyResponse uses when a flush
// interval is in effect. Each Write goes straight to dst. Flushing then
// happens either immediately (latency < 0) or through a timer, so that at
// most one delayed flush is pending at a time. That flush fires latency
// after the write that armed it.
//
// The methods below access t and flushPending only while holding mu,
// because the timer callback (delayedFlush) runs on its own goroutine.
type maxLatencyWriter struct {
	dst     io.Writer
	flush   func() error
	latency time.Duration // negative means flush after every Write

	mu           sync.Mutex
	t            *time.Timer
	flushPending bool
}

// Write writes p to dst and then schedules a flush:
//   - immediately when latency is negative;
//   - otherwise by arming (or re-arming) the timer, unless a delayed flush
//     is already pending.
//
// It returns dst's result. The flush function's error is ignored.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.flush()
		return n, err
	case m.flushPending:
		// A delayed flush is already scheduled and will cover these bytes.
		return n, err
	case m.t == nil:
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
	default:
		m.t.Reset(m.latency)
	}
	m.flushPending = true
	return n, err
}

// delayedFlush is the timer callback. It flushes only if a flush is still
// pending, which stop() and earlier flushes may have cleared.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.flushPending {
		m.flush()
		m.flushPending = false
	}
}

// stop cancels any pending delayed flush and stops the timer, if one was
// ever created.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if timer := m.t; timer != nil {
		timer.Stop()
	}
}
812
813 func upgradeType(h http.Header) string {
814 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
815 return ""
816 }
817 return h.Get("Upgrade")
818 }
819
820 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
821 reqUpType := upgradeType(req.Header)
822 resUpType := upgradeType(res.Header)
823 if !ascii.IsPrint(resUpType) {
824 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
825 return
826 }
827 if !ascii.EqualFold(reqUpType, resUpType) {
828 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
829 return
830 }
831
832 backConn, ok := res.Body.(io.ReadWriteCloser)
833 if !ok {
834 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
835 return
836 }
837
838 rc := http.NewResponseController(rw)
839 conn, brw, hijackErr := rc.Hijack()
840 if errors.Is(hijackErr, http.ErrNotSupported) {
841 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
842 return
843 }
844
845 backConnCloseCh := make(chan bool)
846 go func() {
847
848
849 select {
850 case <-req.Context().Done():
851 case <-backConnCloseCh:
852 }
853 backConn.Close()
854 }()
855 defer close(backConnCloseCh)
856
857 if hijackErr != nil {
858 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
859 return
860 }
861 defer conn.Close()
862
863 copyHeader(rw.Header(), res.Header)
864
865 res.Header = rw.Header()
866 res.Body = nil
867 if err := res.Write(brw); err != nil {
868 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
869 return
870 }
871 if err := brw.Flush(); err != nil {
872 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
873 return
874 }
875 errc := make(chan error, 1)
876 spc := switchProtocolCopier{user: conn, backend: backConn}
877 go spc.copyToBackend(errc)
878 go spc.copyFromBackend(errc)
879
880
881
882 err := <-errc
883 if err == nil {
884 err = <-errc
885 }
886 }
887
// errCopyDone is reported by a switchProtocolCopier direction that copied
// everything up to EOF but whose destination has no CloseWrite method, so it
// could not signal end-of-stream. handleUpgradeResponse treats it, like any
// other non-nil result, as the end of the splice.
var errCopyDone = errors.New("hijacked connection copy complete")

// switchProtocolCopier relays bytes between the hijacked client connection
// (user) and the backend connection after a protocol switch. Each method
// copies one direction and sends exactly one result on errc.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies backend to user, then half-closes user if it can.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	errc <- copyThenCloseWrite(c.user, c.backend)
}

// copyToBackend copies user to backend, then half-closes backend if it can.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	errc <- copyThenCloseWrite(c.backend, c.user)
}

// copyThenCloseWrite copies src to dst until EOF. On a copy error it
// returns that error. Otherwise it returns the result of dst.CloseWrite()
// when dst supports half-closing, or errCopyDone when it does not.
func copyThenCloseWrite(dst io.Writer, src io.Reader) error {
	if _, err := io.Copy(dst, src); err != nil {
		return err
	}
	closer, ok := dst.(interface{ CloseWrite() error })
	if !ok {
		return errCopyDone
	}
	return closer.CloseWrite()
}
925
// urlmaxqueryparams is the GODEBUG setting of the same name.
// cleanQueryParams only checks whether it has a non-empty value; when it
// does, queries are always re-encoded.
var urlmaxqueryparams = godebug.New("urlmaxqueryparams")

// defaultMaxParams is the query-parameter count above which cleanQueryParams
// always re-encodes the query. NOTE(review): presumably this mirrors the
// default parameter limit that net/url applies when parsing queries;
// confirm.
const defaultMaxParams = 10000
930
931 func cleanQueryParams(s string) string {
932 reencode := func(s string) string {
933 v, _ := url.ParseQuery(s)
934 return v.Encode()
935 }
936 if urlmaxqueryparams.Value() != "" {
937
938 return reencode(s)
939 }
940 if numParams := strings.Count(s, "&") + 1; numParams > defaultMaxParams {
941
942 return reencode(s)
943 }
944 for i := 0; i < len(s); {
945 switch s[i] {
946 case ';':
947 return reencode(s)
948 case '%':
949 if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
950 return reencode(s)
951 }
952 i += 3
953 default:
954 i++
955 }
956 }
957 return s
958 }
959
// ishex reports whether c is an ASCII hexadecimal digit: 0-9, a-f, or A-F.
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')
}
971
// noopCloseReader wraps the inbound request body that is handed to the
// transport. Close only records that the body was closed, leaving the
// wrapped ReadCloser open, and every Read after that fails.
type noopCloseReader struct {
	readCloser io.ReadCloser
	closed     atomic.Bool
}

// Read forwards to the wrapped reader until Close has been called. From
// then on it returns an error without touching the wrapped reader.
func (r *noopCloseReader) Read(p []byte) (int, error) {
	if !r.closed.Load() {
		return r.readCloser.Read(p)
	}
	return 0, errors.New("ReverseProxy does an invalid Read on closed Body")
}

// Close marks the reader closed without closing the wrapped body. It always
// returns nil and may be called more than once, from any goroutine.
func (r *noopCloseReader) Close() error {
	r.closed.Store(true)
	return nil
}
988
View as plain text