From: qydysky Date: Wed, 13 Dec 2023 17:37:43 +0000 (+0800) Subject: 1 X-Git-Tag: v0.1.20231214143418~1 X-Git-Url: http://127.0.0.1:8081/?a=commitdiff_plain;h=3197160ee526732e5f2555194ceb848af0042423;p=front%2F.git 1 --- diff --git a/go.mod b/go.mod index 6a2b18b..a192fb2 100644 --- a/go.mod +++ b/go.mod @@ -21,5 +21,3 @@ require ( golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -// replace github.com/qydysky/part => ../part diff --git a/main.go b/main.go index 95965e6..d643ed2 100644 --- a/main.go +++ b/main.go @@ -1,22 +1,26 @@ package front import ( + "bufio" + "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" - "math/rand" + "log" + "net" "net/http" - "net/http/cookiejar" - "strings" + "net/http/httptrace" + "net/url" "time" _ "unsafe" "github.com/gorilla/websocket" pctx "github.com/qydysky/part/ctx" pweb "github.com/qydysky/part/web" + "golang.org/x/net/proxy" ) type Logger interface { @@ -166,6 +170,9 @@ func loadConfig(buf []byte, configF File, configS *[]Config) error { return nil } +//go:linkname nanotime1 runtime.nanotime1 +func nanotime1() int64 + func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, logger Logger) error { configS.lock.RLock() defer configS.lock.RUnlock() @@ -226,7 +233,7 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log case 1: backI = backArray[0] default: - backI = backArray[rand.Int63()%int64(len(backArray))] + backI = backArray[nanotime1()%int64(len(backArray))] } if !backI.IsLive() { backI = nil @@ -355,36 +362,29 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route url = "ws" + url - tmpHeader := make(http.Header) - - dialer := websocket.DefaultDialer - - if strings.Contains(r.Header.Get("Sec-Websocket-Extensions"), "permessage-deflate") { - dialer.EnableCompression = true - } + reqHeader := make(http.Header) - if e := copyHeader(r.Header, tmpHeader, back.ReqHeader, delWsHeaders); e != nil { + if e := copyHeader(r.Header, reqHeader, back.ReqHeader); e != nil { logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) return e } - jar, _ := cookiejar.New(nil) - jar.SetCookies(r.URL, r.Cookies()) - - dialer.Jar = jar + fmt.Println(reqHeader.Get("Sec-WebSocket-Key")) - if res, resp, e := dialer.Dial(url, tmpHeader); e != nil { + if res, resp, e := DialContext(websocket.DefaultDialer, context.Background(), url, reqHeader); e != nil { return errors.Join(ErrReqDoFail, e) } else { defer res.Close() - clear(tmpHeader) - if e := copyHeader(resp.Header, tmpHeader, back.ResHeader, delWsHeaders); e != nil { + resHeader := make(http.Header) + if e := copyHeader(resp.Header, resHeader, back.ResHeader); e != nil { logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) return e } - if req, e := (&websocket.Upgrader{}).Upgrade(w, r, tmpHeader); e != nil { + fmt.Println(resHeader.Get("Sec-Websocket-Accept")) + + if req, e := Upgrade(&websocket.Upgrader{}, w, r, resHeader); e != nil { return errors.Join(ErrResDoFail, e) } else { defer req.Close() @@ -419,23 +419,10 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route } } -var delWsHeaders = map[string]*any{ - `Connection`: nil, - `Upgrade`: nil, - `Sec-Websocket-Accept`: nil, - `Sec-Websocket-Key`: nil, - `Sec-Websocket-Version`: nil, - `Sec-Websocket-Protocol`: nil, - `Sec-Websocket-Extensions`: nil, -} - -func copyHeader(s, t http.Header, app []Header, delHeader ...map[string]*any) error { +func copyHeader(s, t http.Header, app []Header) error { sm := (map[string][]string)(s) tm := (map[string][]string)(t) for k, v := range sm { - if _, ok := delWsHeaders[k]; ok { - continue - } tm[k] = v } for _, v := range app { @@ -460,3 +447,442 @@ func copyWsMsg(dst io.Writer, src io.Reader) (e error) { _, e = io.Copy(dst, src) return } + +func DialContext(d *websocket.Dialer, ctx context.Context, urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) { + if d == nil { + d = websocket.DefaultDialer + } + + challengeKey := requestHeader.Get("Sec-WebSocket-Key") + + u, err := url.Parse(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req = req.WithContext(ctx) + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Sec-Websocket-Protocol": + return nil, nil, errors.New("websocket: header not allowed: " + k) + default: + req.Header[k] = vs + } + } + + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } + + // Get network dial function. + var netDial func(network, add string) (net.Conn, error) + + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { + netDialer := &net.Dialer{} + netDial = func(network, addr string) (net.Conn, error) { + return netDialer.DialContext(ctx, network, addr) + } + } + + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + forwardDial := netDial + netDial = func(network, addr string) (net.Conn, error) { + c, err := forwardDial(network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + if err := c.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + return c, nil + } + } + + // If needed, wrap the dial function to connect through a proxy. + if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial)) + if err != nil { + return nil, nil, err + } + netDial = dialer.Dial + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + + netConn, err := netDial("tcp", hostPort) + if err != nil { + return nil, nil, err + } + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + + defer func() { + if netConn != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + } + }() + + if u.Scheme == "https" && d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done + + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + + if err != nil { + return nil, nil, err + } + } + + var br *bufio.Reader + if br == nil { + if d.ReadBufferSize == 0 { + d.ReadBufferSize = defaultReadBufferSize + } else if d.ReadBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame + d.ReadBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(netConn, d.ReadBufferSize) + } + + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, br, nil) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + + resp, err := http.ReadResponse(br, req) + if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } + return nil, nil, err + } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + + if resp.StatusCode != 101 || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, websocket.ErrBadHandshake + } + + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) + + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, nil, err + } + netConn = nil // to avoid close in defer. + return conn, resp, nil +} + +type netDialerFunc func(network, addr string) (net.Conn, error) + +func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return fn(network, addr) +} + +//go:linkname doHandshake github.com/gorilla/websocket.doHandshake +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error + +//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig +func cloneTLSConfig(cfg *tls.Config) *tls.Config + +//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) + +//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL +var errMalformedURL error + +//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression +var errInvalidCompression error + +//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey +func generateChallengeKey() (string, error) + +//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue +func tokenListContainsValue(header http.Header, name string, value string) bool + +//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError +func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error) + +//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin +func checkSameOrigin(r *http.Request) bool + +//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey +func isValidChallengeKey(s string) bool + +//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol +func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string + +//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions +func parseExtensions(header http.Header) []map[string]string + +//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize +func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int + +//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer +func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte + +//go:linkname newConn github.com/gorilla/websocket.newConn +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool websocket.BufferPool, br *bufio.Reader, writeBuf []byte) *websocket.Conn + +//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey +func computeAcceptKey(challengeKey string) string + +const ( + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + maxControlFramePayloadSize = 125 +) + +func Upgrade(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) { + const badHandshake = "websocket: the client is not using the websocket protocol: " + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") + } + + if r.Method != http.MethodGet { + return returnError(u, w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return returnError(u, w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(r) { + return returnError(u, w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") + } + + challengeKey := r.Header.Get("Sec-Websocket-Key") + if !isValidChallengeKey(challengeKey) { + return returnError(u, w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") + } + + h, ok := w.(http.Hijacker) + if !ok { + return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") + } + var brw *bufio.ReadWriter + netConn, brw, err := h.Hijack() + if err != nil { + return returnError(u, w, r, http.StatusInternalServerError, err.Error()) + } + + if brw.Reader.Buffered() > 0 { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, errors.New("websocket: client sent data before handshake is complete") + } + + var br *bufio.Reader + if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { + // Reuse hijacked buffered reader as connection reader. + br = brw.Reader + } + + buf := bufioWriterBuffer(netConn, brw.Writer) + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } else { + if u.WriteBufferSize <= 0 { + u.WriteBufferSize = defaultWriteBufferSize + } + u.WriteBufferSize += maxFrameHeaderSize + if writeBuf == nil && u.WriteBufferPool == nil { + writeBuf = make([]byte, u.WriteBufferSize) + } + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) + + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(writeBuf) > len(p) { + p = writeBuf + } + p = p[:0] + + p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...) + for k, vs := range responseHeader { + if k == "Sec-Websocket-Protocol" || k == "Sec-Websocket-Extensions" { + continue + } + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + } + if _, err = netConn.Write(p); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + } + + return c, nil +}