]> 127.0.0.1 Git - front/.git/commitdiff
1
authorqydysky <qydysky@foxmail.com>
Wed, 13 Dec 2023 17:37:43 +0000 (01:37 +0800)
committerqydysky <qydysky@foxmail.com>
Wed, 13 Dec 2023 17:37:43 +0000 (01:37 +0800)
go.mod
main.go

diff --git a/go.mod b/go.mod
index 6a2b18b9d9f3dcf4dd818901536c645c7e75fb0e..a192fb29469695e5c92eec827e18c78cfb973b22 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -21,5 +21,3 @@ require (
        golang.org/x/text v0.14.0 // indirect
        gopkg.in/yaml.v3 v3.0.1 // indirect
 )
-
-// replace github.com/qydysky/part => ../part
diff --git a/main.go b/main.go
index 95965e69e400adcd67965f6cbbca6372090dc1ef..d643ed2a9ffe453491f78b31cb2f099b6f8be279 100644 (file)
--- a/main.go
+++ b/main.go
@@ -1,22 +1,26 @@
 package front
 
 import (
+       "bufio"
+       "bytes"
        "context"
        "crypto/tls"
        "encoding/json"
        "errors"
        "fmt"
        "io"
-       "math/rand"
+       "log"
+       "net"
        "net/http"
-       "net/http/cookiejar"
-       "strings"
+       "net/http/httptrace"
+       "net/url"
        "time"
        _ "unsafe"
 
        "github.com/gorilla/websocket"
        pctx "github.com/qydysky/part/ctx"
        pweb "github.com/qydysky/part/web"
+       "golang.org/x/net/proxy"
 )
 
 type Logger interface {
@@ -166,6 +170,9 @@ func loadConfig(buf []byte, configF File, configS *[]Config) error {
        return nil
 }
 
+//go:linkname nanotime1 runtime.nanotime1
+func nanotime1() int64
+
 func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, logger Logger) error {
        configS.lock.RLock()
        defer configS.lock.RUnlock()
@@ -226,7 +233,7 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log
                                case 1:
                                        backI = backArray[0]
                                default:
-                                       backI = backArray[rand.Int63()%int64(len(backArray))]
+                                       backI = backArray[nanotime1()%int64(len(backArray))]
                                }
                                if !backI.IsLive() {
                                        backI = nil
@@ -355,36 +362,29 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
 
        url = "ws" + url
 
-       tmpHeader := make(http.Header)
-
-       dialer := websocket.DefaultDialer
-
-       if strings.Contains(r.Header.Get("Sec-Websocket-Extensions"), "permessage-deflate") {
-               dialer.EnableCompression = true
-       }
+       reqHeader := make(http.Header)
 
-       if e := copyHeader(r.Header, tmpHeader, back.ReqHeader, delWsHeaders); e != nil {
+       if e := copyHeader(r.Header, reqHeader, back.ReqHeader); e != nil {
                logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
                return e
        }
 
-       jar, _ := cookiejar.New(nil)
-       jar.SetCookies(r.URL, r.Cookies())
-
-       dialer.Jar = jar
+       fmt.Println(reqHeader.Get("Sec-WebSocket-Key"))
 
-       if res, resp, e := dialer.Dial(url, tmpHeader); e != nil {
+       if res, resp, e := DialContext(websocket.DefaultDialer, context.Background(), url, reqHeader); e != nil {
                return errors.Join(ErrReqDoFail, e)
        } else {
                defer res.Close()
 
-               clear(tmpHeader)
-               if e := copyHeader(resp.Header, tmpHeader, back.ResHeader, delWsHeaders); e != nil {
+               resHeader := make(http.Header)
+               if e := copyHeader(resp.Header, resHeader, back.ResHeader); e != nil {
                        logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
                        return e
                }
 
-               if req, e := (&websocket.Upgrader{}).Upgrade(w, r, tmpHeader); e != nil {
+               fmt.Println(resHeader.Get("Sec-Websocket-Accept"))
+
+               if req, e := Upgrade(&websocket.Upgrader{}, w, r, resHeader); e != nil {
                        return errors.Join(ErrResDoFail, e)
                } else {
                        defer req.Close()
@@ -419,23 +419,10 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
        }
 }
 
-var delWsHeaders = map[string]*any{
-       `Connection`:               nil,
-       `Upgrade`:                  nil,
-       `Sec-Websocket-Accept`:     nil,
-       `Sec-Websocket-Key`:        nil,
-       `Sec-Websocket-Version`:    nil,
-       `Sec-Websocket-Protocol`:   nil,
-       `Sec-Websocket-Extensions`: nil,
-}
-
-func copyHeader(s, t http.Header, app []Header, delHeader ...map[string]*any) error {
+func copyHeader(s, t http.Header, app []Header) error {
        sm := (map[string][]string)(s)
        tm := (map[string][]string)(t)
        for k, v := range sm {
-               if _, ok := delWsHeaders[k]; ok {
-                       continue
-               }
                tm[k] = v
        }
        for _, v := range app {
@@ -460,3 +447,442 @@ func copyWsMsg(dst io.Writer, src io.Reader) (e error) {
        _, e = io.Copy(dst, src)
        return
 }
+
+func DialContext(d *websocket.Dialer, ctx context.Context, urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) {
+       if d == nil {
+               d = websocket.DefaultDialer
+       }
+
+       challengeKey := requestHeader.Get("Sec-WebSocket-Key")
+
+       u, err := url.Parse(urlStr)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       switch u.Scheme {
+       case "ws":
+               u.Scheme = "http"
+       case "wss":
+               u.Scheme = "https"
+       default:
+               return nil, nil, errMalformedURL
+       }
+
+       if u.User != nil {
+               // User name and password are not allowed in websocket URIs.
+               return nil, nil, errMalformedURL
+       }
+
+       req := &http.Request{
+               Method:     http.MethodGet,
+               URL:        u,
+               Proto:      "HTTP/1.1",
+               ProtoMajor: 1,
+               ProtoMinor: 1,
+               Header:     make(http.Header),
+               Host:       u.Host,
+       }
+       req = req.WithContext(ctx)
+
+       // Set the request headers using the capitalization for names and values in
+       // RFC examples. Although the capitalization shouldn't matter, there are
+       // servers that depend on it. The Header.Set method is not used because the
+       // method canonicalizes the header names.
+       for k, vs := range requestHeader {
+               switch {
+               case k == "Host":
+                       if len(vs) > 0 {
+                               req.Host = vs[0]
+                       }
+               case k == "Sec-Websocket-Protocol":
+                       return nil, nil, errors.New("websocket: header not allowed: " + k)
+               default:
+                       req.Header[k] = vs
+               }
+       }
+
+       if d.HandshakeTimeout != 0 {
+               var cancel func()
+               ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+               defer cancel()
+       }
+
+       // Get network dial function.
+       var netDial func(network, add string) (net.Conn, error)
+
+       switch u.Scheme {
+       case "http":
+               if d.NetDialContext != nil {
+                       netDial = func(network, addr string) (net.Conn, error) {
+                               return d.NetDialContext(ctx, network, addr)
+                       }
+               } else if d.NetDial != nil {
+                       netDial = d.NetDial
+               }
+       case "https":
+               if d.NetDialTLSContext != nil {
+                       netDial = func(network, addr string) (net.Conn, error) {
+                               return d.NetDialTLSContext(ctx, network, addr)
+                       }
+               } else if d.NetDialContext != nil {
+                       netDial = func(network, addr string) (net.Conn, error) {
+                               return d.NetDialContext(ctx, network, addr)
+                       }
+               } else if d.NetDial != nil {
+                       netDial = d.NetDial
+               }
+       default:
+               return nil, nil, errMalformedURL
+       }
+
+       if netDial == nil {
+               netDialer := &net.Dialer{}
+               netDial = func(network, addr string) (net.Conn, error) {
+                       return netDialer.DialContext(ctx, network, addr)
+               }
+       }
+
+       // If needed, wrap the dial function to set the connection deadline.
+       if deadline, ok := ctx.Deadline(); ok {
+               forwardDial := netDial
+               netDial = func(network, addr string) (net.Conn, error) {
+                       c, err := forwardDial(network, addr)
+                       if err != nil {
+                               return nil, err
+                       }
+                       err = c.SetDeadline(deadline)
+                       if err != nil {
+                               if err := c.Close(); err != nil {
+                                       log.Printf("websocket: failed to close network connection: %v", err)
+                               }
+                               return nil, err
+                       }
+                       return c, nil
+               }
+       }
+
+       // If needed, wrap the dial function to connect through a proxy.
+       if d.Proxy != nil {
+               proxyURL, err := d.Proxy(req)
+               if err != nil {
+                       return nil, nil, err
+               }
+               if proxyURL != nil {
+                       dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
+                       if err != nil {
+                               return nil, nil, err
+                       }
+                       netDial = dialer.Dial
+               }
+       }
+
+       hostPort, hostNoPort := hostPortNoPort(u)
+       trace := httptrace.ContextClientTrace(ctx)
+       if trace != nil && trace.GetConn != nil {
+               trace.GetConn(hostPort)
+       }
+
+       netConn, err := netDial("tcp", hostPort)
+       if err != nil {
+               return nil, nil, err
+       }
+       if trace != nil && trace.GotConn != nil {
+               trace.GotConn(httptrace.GotConnInfo{
+                       Conn: netConn,
+               })
+       }
+
+       defer func() {
+               if netConn != nil {
+                       if err := netConn.Close(); err != nil {
+                               log.Printf("websocket: failed to close network connection: %v", err)
+                       }
+               }
+       }()
+
+       if u.Scheme == "https" && d.NetDialTLSContext == nil {
+               // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
+               cfg := cloneTLSConfig(d.TLSClientConfig)
+               if cfg.ServerName == "" {
+                       cfg.ServerName = hostNoPort
+               }
+               tlsConn := tls.Client(netConn, cfg)
+               netConn = tlsConn
+
+               if trace != nil && trace.TLSHandshakeStart != nil {
+                       trace.TLSHandshakeStart()
+               }
+               err := doHandshake(ctx, tlsConn, cfg)
+               if trace != nil && trace.TLSHandshakeDone != nil {
+                       trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
+               }
+
+               if err != nil {
+                       return nil, nil, err
+               }
+       }
+
+       var br *bufio.Reader
+       if br == nil {
+               if d.ReadBufferSize == 0 {
+                       d.ReadBufferSize = defaultReadBufferSize
+               } else if d.ReadBufferSize < maxControlFramePayloadSize {
+                       // must be large enough for control frame
+                       d.ReadBufferSize = maxControlFramePayloadSize
+               }
+               br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
+       }
+
+       conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, br, nil)
+
+       if err := req.Write(netConn); err != nil {
+               return nil, nil, err
+       }
+
+       if trace != nil && trace.GotFirstResponseByte != nil {
+               if peek, err := br.Peek(1); err == nil && len(peek) == 1 {
+                       trace.GotFirstResponseByte()
+               }
+       }
+
+       resp, err := http.ReadResponse(br, req)
+       if err != nil {
+               if d.TLSClientConfig != nil {
+                       for _, proto := range d.TLSClientConfig.NextProtos {
+                               if proto != "http/1.1" {
+                                       return nil, nil, fmt.Errorf(
+                                               "websocket: protocol %q was given but is not supported;"+
+                                                       "sharing tls.Config with net/http Transport can cause this error: %w",
+                                               proto, err,
+                                       )
+                               }
+                       }
+               }
+               return nil, nil, err
+       }
+
+       if d.Jar != nil {
+               if rc := resp.Cookies(); len(rc) > 0 {
+                       d.Jar.SetCookies(u, rc)
+               }
+       }
+
+       if resp.StatusCode != 101 ||
+               !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+               !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
+               resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+               // Before closing the network connection on return from this
+               // function, slurp up some of the response to aid application
+               // debugging.
+               buf := make([]byte, 1024)
+               n, _ := io.ReadFull(resp.Body, buf)
+               resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
+               return nil, resp, websocket.ErrBadHandshake
+       }
+
+       resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
+
+       if err := netConn.SetDeadline(time.Time{}); err != nil {
+               return nil, nil, err
+       }
+       netConn = nil // to avoid close in defer.
+       return conn, resp, nil
+}
+
+type netDialerFunc func(network, addr string) (net.Conn, error)
+
+func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
+       return fn(network, addr)
+}
+
+//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
+func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
+
+//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
+func cloneTLSConfig(cfg *tls.Config) *tls.Config
+
+//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
+
+//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
+var errMalformedURL error
+
+//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
+var errInvalidCompression error
+
+//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
+func generateChallengeKey() (string, error)
+
+//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
+func tokenListContainsValue(header http.Header, name string, value string) bool
+
+//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
+func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
+
+//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
+func checkSameOrigin(r *http.Request) bool
+
+//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
+func isValidChallengeKey(s string) bool
+
+//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
+func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
+
+//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
+func parseExtensions(header http.Header) []map[string]string
+
+//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
+func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
+
+//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
+func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
+
+//go:linkname newConn github.com/gorilla/websocket.newConn
+func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool websocket.BufferPool, br *bufio.Reader, writeBuf []byte) *websocket.Conn
+
+//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
+func computeAcceptKey(challengeKey string) string
+
+const (
+       maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
+       defaultReadBufferSize      = 4096
+       defaultWriteBufferSize     = 4096
+       maxControlFramePayloadSize = 125
+)
+
+func Upgrade(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) {
+       const badHandshake = "websocket: the client is not using the websocket protocol: "
+
+       if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
+               return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
+       }
+
+       if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
+               return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
+       }
+
+       if r.Method != http.MethodGet {
+               return returnError(u, w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
+       }
+
+       if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
+               return returnError(u, w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
+       }
+
+       checkOrigin := u.CheckOrigin
+       if checkOrigin == nil {
+               checkOrigin = checkSameOrigin
+       }
+       if !checkOrigin(r) {
+               return returnError(u, w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
+       }
+
+       challengeKey := r.Header.Get("Sec-Websocket-Key")
+       if !isValidChallengeKey(challengeKey) {
+               return returnError(u, w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
+       }
+
+       h, ok := w.(http.Hijacker)
+       if !ok {
+               return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
+       }
+       var brw *bufio.ReadWriter
+       netConn, brw, err := h.Hijack()
+       if err != nil {
+               return returnError(u, w, r, http.StatusInternalServerError, err.Error())
+       }
+
+       if brw.Reader.Buffered() > 0 {
+               if err := netConn.Close(); err != nil {
+                       log.Printf("websocket: failed to close network connection: %v", err)
+               }
+               return nil, errors.New("websocket: client sent data before handshake is complete")
+       }
+
+       var br *bufio.Reader
+       if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
+               // Reuse hijacked buffered reader as connection reader.
+               br = brw.Reader
+       }
+
+       buf := bufioWriterBuffer(netConn, brw.Writer)
+
+       var writeBuf []byte
+       if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
+               // Reuse hijacked write buffer as connection buffer.
+               writeBuf = buf
+       } else {
+               if u.WriteBufferSize <= 0 {
+                       u.WriteBufferSize = defaultWriteBufferSize
+               }
+               u.WriteBufferSize += maxFrameHeaderSize
+               if writeBuf == nil && u.WriteBufferPool == nil {
+                       writeBuf = make([]byte, u.WriteBufferSize)
+               }
+       }
+
+       c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
+
+       // Use larger of hijacked buffer and connection write buffer for header.
+       p := buf
+       if len(writeBuf) > len(p) {
+               p = writeBuf
+       }
+       p = p[:0]
+
+       p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
+       for k, vs := range responseHeader {
+               if k == "Sec-Websocket-Protocol" || k == "Sec-Websocket-Extensions" {
+                       continue
+               }
+               for _, v := range vs {
+                       p = append(p, k...)
+                       p = append(p, ": "...)
+                       for i := 0; i < len(v); i++ {
+                               b := v[i]
+                               if b <= 31 {
+                                       // prevent response splitting.
+                                       b = ' '
+                               }
+                               p = append(p, b)
+                       }
+                       p = append(p, "\r\n"...)
+               }
+       }
+       p = append(p, "\r\n"...)
+
+       // Clear deadlines set by HTTP server.
+       if err := netConn.SetDeadline(time.Time{}); err != nil {
+               if err := netConn.Close(); err != nil {
+                       log.Printf("websocket: failed to close network connection: %v", err)
+               }
+               return nil, err
+       }
+
+       if u.HandshakeTimeout > 0 {
+               if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
+                       if err := netConn.Close(); err != nil {
+                               log.Printf("websocket: failed to close network connection: %v", err)
+                       }
+                       return nil, err
+               }
+       }
+       if _, err = netConn.Write(p); err != nil {
+               if err := netConn.Close(); err != nil {
+                       log.Printf("websocket: failed to close network connection: %v", err)
+               }
+               return nil, err
+       }
+       if u.HandshakeTimeout > 0 {
+               if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
+                       if err := netConn.Close(); err != nil {
+                               log.Printf("websocket: failed to close network connection: %v", err)
+                       }
+                       return nil, err
+               }
+       }
+
+       return c, nil
+}