package front
import (
+ "bufio"
+ "bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
- "math/rand"
+ "log"
+ "net"
"net/http"
- "net/http/cookiejar"
- "strings"
+ "net/http/httptrace"
+ "net/url"
"time"
_ "unsafe"
"github.com/gorilla/websocket"
pctx "github.com/qydysky/part/ctx"
pweb "github.com/qydysky/part/web"
+ "golang.org/x/net/proxy"
)
type Logger interface {
return nil
}
+//go:linkname nanotime1 runtime.nanotime1
+func nanotime1() int64
+
func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, logger Logger) error {
configS.lock.RLock()
defer configS.lock.RUnlock()
case 1:
backI = backArray[0]
default:
- backI = backArray[rand.Int63()%int64(len(backArray))]
+ backI = backArray[nanotime1()%int64(len(backArray))]
}
if !backI.IsLive() {
backI = nil
url = "ws" + url
- tmpHeader := make(http.Header)
-
- dialer := websocket.DefaultDialer
-
- if strings.Contains(r.Header.Get("Sec-Websocket-Extensions"), "permessage-deflate") {
- dialer.EnableCompression = true
- }
+ reqHeader := make(http.Header)
- if e := copyHeader(r.Header, tmpHeader, back.ReqHeader, delWsHeaders); e != nil {
+ if e := copyHeader(r.Header, reqHeader, back.ReqHeader); e != nil {
logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
return e
}
- jar, _ := cookiejar.New(nil)
- jar.SetCookies(r.URL, r.Cookies())
-
- dialer.Jar = jar
+ fmt.Println(reqHeader.Get("Sec-WebSocket-Key"))
- if res, resp, e := dialer.Dial(url, tmpHeader); e != nil {
+ if res, resp, e := DialContext(websocket.DefaultDialer, context.Background(), url, reqHeader); e != nil {
return errors.Join(ErrReqDoFail, e)
} else {
defer res.Close()
- clear(tmpHeader)
- if e := copyHeader(resp.Header, tmpHeader, back.ResHeader, delWsHeaders); e != nil {
+ resHeader := make(http.Header)
+ if e := copyHeader(resp.Header, resHeader, back.ResHeader); e != nil {
logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
return e
}
- if req, e := (&websocket.Upgrader{}).Upgrade(w, r, tmpHeader); e != nil {
+ fmt.Println(resHeader.Get("Sec-Websocket-Accept"))
+
+ if req, e := Upgrade(&websocket.Upgrader{}, w, r, resHeader); e != nil {
return errors.Join(ErrResDoFail, e)
} else {
defer req.Close()
}
}
-var delWsHeaders = map[string]*any{
- `Connection`: nil,
- `Upgrade`: nil,
- `Sec-Websocket-Accept`: nil,
- `Sec-Websocket-Key`: nil,
- `Sec-Websocket-Version`: nil,
- `Sec-Websocket-Protocol`: nil,
- `Sec-Websocket-Extensions`: nil,
-}
-
-func copyHeader(s, t http.Header, app []Header, delHeader ...map[string]*any) error {
+func copyHeader(s, t http.Header, app []Header) error {
sm := (map[string][]string)(s)
tm := (map[string][]string)(t)
for k, v := range sm {
- if _, ok := delWsHeaders[k]; ok {
- continue
- }
tm[k] = v
}
for _, v := range app {
_, e = io.Copy(dst, src)
return
}
+
+func DialContext(d *websocket.Dialer, ctx context.Context, urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) {
+ if d == nil {
+ d = websocket.DefaultDialer
+ }
+
+ challengeKey := requestHeader.Get("Sec-WebSocket-Key")
+
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ switch u.Scheme {
+ case "ws":
+ u.Scheme = "http"
+ case "wss":
+ u.Scheme = "https"
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if u.User != nil {
+ // User name and password are not allowed in websocket URIs.
+ return nil, nil, errMalformedURL
+ }
+
+ req := &http.Request{
+ Method: http.MethodGet,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: u.Host,
+ }
+ req = req.WithContext(ctx)
+
+ // Set the request headers using the capitalization for names and values in
+ // RFC examples. Although the capitalization shouldn't matter, there are
+ // servers that depend on it. The Header.Set method is not used because the
+ // method canonicalizes the header names.
+ for k, vs := range requestHeader {
+ switch {
+ case k == "Host":
+ if len(vs) > 0 {
+ req.Host = vs[0]
+ }
+ case k == "Sec-Websocket-Protocol":
+ return nil, nil, errors.New("websocket: header not allowed: " + k)
+ default:
+ req.Header[k] = vs
+ }
+ }
+
+ if d.HandshakeTimeout != 0 {
+ var cancel func()
+ ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+ defer cancel()
+ }
+
+ // Get network dial function.
+ var netDial func(network, add string) (net.Conn, error)
+
+ switch u.Scheme {
+ case "http":
+ if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
+ }
+ case "https":
+ if d.NetDialTLSContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialTLSContext(ctx, network, addr)
+ }
+ } else if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
+ }
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if netDial == nil {
+ netDialer := &net.Dialer{}
+ netDial = func(network, addr string) (net.Conn, error) {
+ return netDialer.DialContext(ctx, network, addr)
+ }
+ }
+
+ // If needed, wrap the dial function to set the connection deadline.
+ if deadline, ok := ctx.Deadline(); ok {
+ forwardDial := netDial
+ netDial = func(network, addr string) (net.Conn, error) {
+ c, err := forwardDial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ err = c.SetDeadline(deadline)
+ if err != nil {
+ if err := c.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ return c, nil
+ }
+ }
+
+ // If needed, wrap the dial function to connect through a proxy.
+ if d.Proxy != nil {
+ proxyURL, err := d.Proxy(req)
+ if err != nil {
+ return nil, nil, err
+ }
+ if proxyURL != nil {
+ dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
+ if err != nil {
+ return nil, nil, err
+ }
+ netDial = dialer.Dial
+ }
+ }
+
+ hostPort, hostNoPort := hostPortNoPort(u)
+ trace := httptrace.ContextClientTrace(ctx)
+ if trace != nil && trace.GetConn != nil {
+ trace.GetConn(hostPort)
+ }
+
+ netConn, err := netDial("tcp", hostPort)
+ if err != nil {
+ return nil, nil, err
+ }
+ if trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{
+ Conn: netConn,
+ })
+ }
+
+ defer func() {
+ if netConn != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ }
+ }()
+
+ if u.Scheme == "https" && d.NetDialTLSContext == nil {
+ // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
+ cfg := cloneTLSConfig(d.TLSClientConfig)
+ if cfg.ServerName == "" {
+ cfg.ServerName = hostNoPort
+ }
+ tlsConn := tls.Client(netConn, cfg)
+ netConn = tlsConn
+
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := doHandshake(ctx, tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
+ }
+
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ var br *bufio.Reader
+ if br == nil {
+ if d.ReadBufferSize == 0 {
+ d.ReadBufferSize = defaultReadBufferSize
+ } else if d.ReadBufferSize < maxControlFramePayloadSize {
+ // must be large enough for control frame
+ d.ReadBufferSize = maxControlFramePayloadSize
+ }
+ br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
+ }
+
+ conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, br, nil)
+
+ if err := req.Write(netConn); err != nil {
+ return nil, nil, err
+ }
+
+ if trace != nil && trace.GotFirstResponseByte != nil {
+ if peek, err := br.Peek(1); err == nil && len(peek) == 1 {
+ trace.GotFirstResponseByte()
+ }
+ }
+
+ resp, err := http.ReadResponse(br, req)
+ if err != nil {
+ if d.TLSClientConfig != nil {
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto != "http/1.1" {
+ return nil, nil, fmt.Errorf(
+ "websocket: protocol %q was given but is not supported;"+
+ "sharing tls.Config with net/http Transport can cause this error: %w",
+ proto, err,
+ )
+ }
+ }
+ }
+ return nil, nil, err
+ }
+
+ if d.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ d.Jar.SetCookies(u, rc)
+ }
+ }
+
+ if resp.StatusCode != 101 ||
+ !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+ !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
+ resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+ // Before closing the network connection on return from this
+ // function, slurp up some of the response to aid application
+ // debugging.
+ buf := make([]byte, 1024)
+ n, _ := io.ReadFull(resp.Body, buf)
+ resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
+ return nil, resp, websocket.ErrBadHandshake
+ }
+
+ resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
+
+ if err := netConn.SetDeadline(time.Time{}); err != nil {
+ return nil, nil, err
+ }
+ netConn = nil // to avoid close in defer.
+ return conn, resp, nil
+}
+
+type netDialerFunc func(network, addr string) (net.Conn, error)
+
+func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
+ return fn(network, addr)
+}
+
+//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
+func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
+
+//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
+func cloneTLSConfig(cfg *tls.Config) *tls.Config
+
+//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
+
+//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
+var errMalformedURL error
+
+//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
+var errInvalidCompression error
+
+//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
+func generateChallengeKey() (string, error)
+
+//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
+func tokenListContainsValue(header http.Header, name string, value string) bool
+
+//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
+func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
+
+//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
+func checkSameOrigin(r *http.Request) bool
+
+//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
+func isValidChallengeKey(s string) bool
+
+//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
+func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
+
+//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
+func parseExtensions(header http.Header) []map[string]string
+
+//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
+func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
+
+//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
+func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
+
+//go:linkname newConn github.com/gorilla/websocket.newConn
+func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool websocket.BufferPool, br *bufio.Reader, writeBuf []byte) *websocket.Conn
+
+//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
+func computeAcceptKey(challengeKey string) string
+
+const (
+ maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
+ defaultReadBufferSize = 4096
+ defaultWriteBufferSize = 4096
+ maxControlFramePayloadSize = 125
+)
+
+func Upgrade(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) {
+ const badHandshake = "websocket: the client is not using the websocket protocol: "
+
+ if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
+ return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
+ }
+
+ if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
+ return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
+ }
+
+ if r.Method != http.MethodGet {
+ return returnError(u, w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
+ }
+
+ if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
+ return returnError(u, w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
+ }
+
+ checkOrigin := u.CheckOrigin
+ if checkOrigin == nil {
+ checkOrigin = checkSameOrigin
+ }
+ if !checkOrigin(r) {
+ return returnError(u, w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
+ }
+
+ challengeKey := r.Header.Get("Sec-Websocket-Key")
+ if !isValidChallengeKey(challengeKey) {
+ return returnError(u, w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
+ }
+
+ h, ok := w.(http.Hijacker)
+ if !ok {
+ return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
+ }
+ var brw *bufio.ReadWriter
+ netConn, brw, err := h.Hijack()
+ if err != nil {
+ return returnError(u, w, r, http.StatusInternalServerError, err.Error())
+ }
+
+ if brw.Reader.Buffered() > 0 {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, errors.New("websocket: client sent data before handshake is complete")
+ }
+
+ var br *bufio.Reader
+ if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
+ // Reuse hijacked buffered reader as connection reader.
+ br = brw.Reader
+ }
+
+ buf := bufioWriterBuffer(netConn, brw.Writer)
+
+ var writeBuf []byte
+ if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
+ // Reuse hijacked write buffer as connection buffer.
+ writeBuf = buf
+ } else {
+ if u.WriteBufferSize <= 0 {
+ u.WriteBufferSize = defaultWriteBufferSize
+ }
+ u.WriteBufferSize += maxFrameHeaderSize
+ if writeBuf == nil && u.WriteBufferPool == nil {
+ writeBuf = make([]byte, u.WriteBufferSize)
+ }
+ }
+
+ c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
+
+ // Use larger of hijacked buffer and connection write buffer for header.
+ p := buf
+ if len(writeBuf) > len(p) {
+ p = writeBuf
+ }
+ p = p[:0]
+
+ p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
+ for k, vs := range responseHeader {
+ if k == "Sec-Websocket-Protocol" || k == "Sec-Websocket-Extensions" {
+ continue
+ }
+ for _, v := range vs {
+ p = append(p, k...)
+ p = append(p, ": "...)
+ for i := 0; i < len(v); i++ {
+ b := v[i]
+ if b <= 31 {
+ // prevent response splitting.
+ b = ' '
+ }
+ p = append(p, b)
+ }
+ p = append(p, "\r\n"...)
+ }
+ }
+ p = append(p, "\r\n"...)
+
+ // Clear deadlines set by HTTP server.
+ if err := netConn.SetDeadline(time.Time{}); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+
+ if u.HandshakeTimeout > 0 {
+ if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ }
+ if _, err = netConn.Write(p); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ if u.HandshakeTimeout > 0 {
+ if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ }
+
+ return c, nil
+}