]> 127.0.0.1 Git - front/.git/commitdiff
1
authorqydysky <qydysky@foxmail.com>
Sat, 31 Aug 2024 12:52:17 +0000 (20:52 +0800)
committerqydysky <qydysky@foxmail.com>
Sat, 31 Aug 2024 12:52:17 +0000 (20:52 +0800)
go.mod
go.sum
http.go
main.go
utils/websocket.go [new file with mode: 0644]
ws.go

diff --git a/go.mod b/go.mod
index 6b5897ebfdeef583d8829fd2b096c24ebeee38f2..f3e5cf92ec3b8775b78689ed8ce9b6afb5aa622c 100755 (executable)
--- a/go.mod
+++ b/go.mod
@@ -16,6 +16,7 @@ require (
        github.com/go-ole/go-ole v1.3.0 // indirect
        github.com/google/uuid v1.6.0 // indirect
        github.com/pmezard/go-difflib v1.0.0 // indirect
+       github.com/qydysky/brotli v0.0.0-20240828134800-e9913a6e7ed9 // indirect
        github.com/shirou/gopsutil v3.21.11+incompatible // indirect
        github.com/tklauser/go-sysconf v0.3.14 // indirect
        github.com/tklauser/numcpus v0.8.0 // indirect
diff --git a/go.sum b/go.sum
index 8d98eebf1de01081eaa61f72be8796d0e72e49b0..9bc0a563f6ea3c86a3fd4259ae9b7246abc3d4dd 100755 (executable)
--- a/go.sum
+++ b/go.sum
@@ -39,6 +39,8 @@ github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZ
 github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
 github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
 github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
+github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
 github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
 github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
 golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
diff --git a/http.go b/http.go
index efc883b4bc69b1a389fe739b226ef6d34e4d1670..6a6fd62e405bed9ebe79a504f0102c40a87d6233 100644 (file)
--- a/http.go
+++ b/http.go
@@ -11,6 +11,7 @@ import (
        "time"
        _ "unsafe"
 
+       preqf "github.com/qydysky/part/reqf"
        pslice "github.com/qydysky/part/slice"
 )
 
@@ -128,7 +129,7 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
                        MaxAge: chosenBack.Splicing(),
                        Path:   "/",
                }
-               if validCookieDomain(r.Host) {
+               if preqf.ValidCookieDomain(r.Host) {
                        cookie.Domain = r.Host
                }
                w.Header().Add("Set-Cookie", (cookie).String())
diff --git a/main.go b/main.go
index 977913249af3922118c6f72c609945322b80ed31..82d461cf566cfcd1905347440d900c45e88607fe 100755 (executable)
--- a/main.go
+++ b/main.go
@@ -14,6 +14,7 @@ import (
        _ "unsafe"
 
        "github.com/qydysky/front/dealer"
+       utils "github.com/qydysky/front/utils"
        pctx "github.com/qydysky/part/ctx"
        pweb "github.com/qydysky/part/web"
 )
@@ -29,9 +30,6 @@ type File interface {
        Read(data []byte) (int, error)
 }
 
-//go:linkname validCookieDomain net/http.validCookieDomain
-func validCookieDomain(v string) bool
-
 // 加载
 func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error {
        var oldBufMd5 string
@@ -86,7 +84,7 @@ func Test(ctx context.Context, port int, logger Logger) {
                                conn, _ := Upgrade(w, r, http.Header{
                                        "Upgrade":              []string{"websocket"},
                                        "Connection":           []string{"upgrade"},
-                                       "Sec-Websocket-Accept": []string{computeAcceptKey(r.Header.Get("Sec-WebSocket-Key"))},
+                                       "Sec-Websocket-Accept": []string{utils.ComputeAcceptKey(r.Header.Get("Sec-WebSocket-Key"))},
                                })
                                conn.Close()
                        } else {
diff --git a/utils/websocket.go b/utils/websocket.go
new file mode 100644 (file)
index 0000000..6597e28
--- /dev/null
@@ -0,0 +1,298 @@
+package utils
+
+import (
+       "bufio"
+       "context"
+       "crypto/sha1"
+       "crypto/tls"
+       "encoding/base64"
+       "errors"
+       "io"
+       "net"
+       "net/http"
+       "net/url"
+       "strings"
+       "unicode/utf8"
+)
+
+func ValidCookieDomain(v string) bool {
+       if isCookieDomainName(v) {
+               return true
+       }
+       if net.ParseIP(v) != nil && !strings.Contains(v, ":") {
+               return true
+       }
+       return false
+}
+
+func isCookieDomainName(s string) bool {
+       if len(s) == 0 {
+               return false
+       }
+       if len(s) > 255 {
+               return false
+       }
+
+       if s[0] == '.' {
+               // A cookie a domain attribute may start with a leading dot.
+               s = s[1:]
+       }
+       last := byte('.')
+       ok := false // Ok once we've seen a letter.
+       partlen := 0
+       for i := 0; i < len(s); i++ {
+               c := s[i]
+               switch {
+               default:
+                       return false
+               case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
+                       // No '_' allowed here (in contrast to package net).
+                       ok = true
+                       partlen++
+               case '0' <= c && c <= '9':
+                       // fine
+                       partlen++
+               case c == '-':
+                       // Byte before dash cannot be dot.
+                       if last == '.' {
+                               return false
+                       }
+                       partlen++
+               case c == '.':
+                       // Byte before dot cannot be dot, dash.
+                       if last == '.' || last == '-' {
+                               return false
+                       }
+                       if partlen > 63 || partlen == 0 {
+                               return false
+                       }
+                       partlen = 0
+               }
+               last = c
+       }
+       if last == '-' || partlen > 63 {
+               return false
+       }
+
+       return ok
+}
+
+var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
+
+func ComputeAcceptKey(challengeKey string) string {
+       h := sha1.New()
+       h.Write([]byte(challengeKey))
+       h.Write(keyGUID)
+       return base64.StdEncoding.EncodeToString(h.Sum(nil))
+}
+
+// writeHook is an io.Writer that records the last slice passed to it vio
+// io.Writer.Write.
+type writeHook struct {
+       p []byte
+}
+
+func (wh *writeHook) Write(p []byte) (int, error) {
+       wh.p = p
+       return len(p), nil
+}
+
+// bufioWriterBuffer grabs the buffer from a bufio.Writer.
+func BufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
+       // This code assumes that bufio.Writer.buf[:1] is passed to the
+       // bufio.Writer's underlying writer.
+       var wh writeHook
+       bw.Reset(&wh)
+       bw.WriteByte(0)
+       bw.Flush()
+
+       bw.Reset(originalWriter)
+
+       return wh.p[:cap(wh.p)]
+}
+func DoHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
+       if err := tlsConn.HandshakeContext(ctx); err != nil {
+               return err
+       }
+       if !cfg.InsecureSkipVerify {
+               if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func CloneTLSConfig(cfg *tls.Config) *tls.Config {
+       if cfg == nil {
+               return &tls.Config{}
+       }
+       return cfg.Clone()
+}
+
+func HostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
+       hostPort = u.Host
+       hostNoPort = u.Host
+       if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
+               hostNoPort = hostNoPort[:i]
+       } else {
+               switch u.Scheme {
+               case "wss":
+                       hostPort += ":443"
+               case "https":
+                       hostPort += ":443"
+               default:
+                       hostPort += ":80"
+               }
+       }
+       return hostPort, hostNoPort
+}
+
+var ErrMalformedURL = errors.New("malformed ws or wss URL")
+
+var ErrInvalidCompression = errors.New("websocket: invalid compression negotiation")
+
+func TokenListContainsValue(header http.Header, name string, value string) bool {
+headers:
+       for _, s := range header[name] {
+               for {
+                       var t string
+                       t, s = nextToken(skipSpace(s))
+                       if t == "" {
+                               continue headers
+                       }
+                       s = skipSpace(s)
+                       if s != "" && s[0] != ',' {
+                               continue headers
+                       }
+                       if equalASCIIFold(t, value) {
+                               return true
+                       }
+                       if s == "" {
+                               continue headers
+                       }
+                       s = s[1:]
+               }
+       }
+       return false
+}
+func nextToken(s string) (token, rest string) {
+       i := 0
+       for ; i < len(s); i++ {
+               if !isTokenOctet[s[i]] {
+                       break
+               }
+       }
+       return s[:i], s[i:]
+}
+
+// Token octets per RFC 2616.
+var isTokenOctet = [256]bool{
+       '!':  true,
+       '#':  true,
+       '$':  true,
+       '%':  true,
+       '&':  true,
+       '\'': true,
+       '*':  true,
+       '+':  true,
+       '-':  true,
+       '.':  true,
+       '0':  true,
+       '1':  true,
+       '2':  true,
+       '3':  true,
+       '4':  true,
+       '5':  true,
+       '6':  true,
+       '7':  true,
+       '8':  true,
+       '9':  true,
+       'A':  true,
+       'B':  true,
+       'C':  true,
+       'D':  true,
+       'E':  true,
+       'F':  true,
+       'G':  true,
+       'H':  true,
+       'I':  true,
+       'J':  true,
+       'K':  true,
+       'L':  true,
+       'M':  true,
+       'N':  true,
+       'O':  true,
+       'P':  true,
+       'Q':  true,
+       'R':  true,
+       'S':  true,
+       'T':  true,
+       'U':  true,
+       'W':  true,
+       'V':  true,
+       'X':  true,
+       'Y':  true,
+       'Z':  true,
+       '^':  true,
+       '_':  true,
+       '`':  true,
+       'a':  true,
+       'b':  true,
+       'c':  true,
+       'd':  true,
+       'e':  true,
+       'f':  true,
+       'g':  true,
+       'h':  true,
+       'i':  true,
+       'j':  true,
+       'k':  true,
+       'l':  true,
+       'm':  true,
+       'n':  true,
+       'o':  true,
+       'p':  true,
+       'q':  true,
+       'r':  true,
+       's':  true,
+       't':  true,
+       'u':  true,
+       'v':  true,
+       'w':  true,
+       'x':  true,
+       'y':  true,
+       'z':  true,
+       '|':  true,
+       '~':  true,
+}
+
+func skipSpace(s string) (rest string) {
+       i := 0
+       for ; i < len(s); i++ {
+               if b := s[i]; b != ' ' && b != '\t' {
+                       break
+               }
+       }
+       return s[i:]
+}
+func equalASCIIFold(s, t string) bool {
+       for s != "" && t != "" {
+               sr, size := utf8.DecodeRuneInString(s)
+               s = s[size:]
+               tr, size := utf8.DecodeRuneInString(t)
+               t = t[size:]
+               if sr == tr {
+                       continue
+               }
+               if 'A' <= sr && sr <= 'Z' {
+                       sr = sr + 'a' - 'A'
+               }
+               if 'A' <= tr && tr <= 'Z' {
+                       tr = tr + 'a' - 'A'
+               }
+               if sr != tr {
+                       return false
+               }
+       }
+       return s == t
+}
diff --git a/ws.go b/ws.go
index 8fd93a9047f748a8901654fa2a32350054afcc11..70875a0f1c7654159df4b9a833136dfec612b509 100644 (file)
--- a/ws.go
+++ b/ws.go
@@ -18,6 +18,7 @@ import (
        _ "unsafe"
 
        "github.com/gorilla/websocket"
+       utils "github.com/qydysky/front/utils"
        pctx "github.com/qydysky/part/ctx"
        pslice "github.com/qydysky/part/slice"
        "golang.org/x/net/proxy"
@@ -107,7 +108,7 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
                        MaxAge: chosenBack.Splicing(),
                        Path:   "/",
                }
-               if validCookieDomain(r.Host) {
+               if utils.ValidCookieDomain(r.Host) {
                        cookie.Domain = r.Host
                }
                w.Header().Add("Set-Cookie", (cookie).String())
@@ -182,12 +183,12 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header,
        case "wss":
                u.Scheme = "https"
        default:
-               return nil, nil, errMalformedURL
+               return nil, nil, utils.ErrMalformedURL
        }
 
        if u.User != nil {
                // User name and password are not allowed in websocket URIs.
-               return nil, nil, errMalformedURL
+               return nil, nil, utils.ErrMalformedURL
        }
 
        req := &http.Request{
@@ -240,7 +241,7 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header,
                        netDial = d.NetDial
                }
        default:
-               return nil, nil, errMalformedURL
+               return nil, nil, utils.ErrMalformedURL
        }
 
        if netDial == nil {
@@ -284,7 +285,7 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header,
                }
        }
 
-       hostPort, hostNoPort := hostPortNoPort(u)
+       hostPort, hostNoPort := utils.HostPortNoPort(u)
        trace := httptrace.ContextClientTrace(ctx)
        if trace != nil && trace.GetConn != nil {
                trace.GetConn(hostPort)
@@ -303,7 +304,7 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header,
        if u.Scheme == "https" && d.NetDialTLSContext == nil {
                // If NetDialTLSContext is set, assume that the TLS handshake has already been done
 
-               cfg := cloneTLSConfig(d.TLSClientConfig)
+               cfg := utils.CloneTLSConfig(d.TLSClientConfig)
                if cfg.ServerName == "" {
                        cfg.ServerName = hostNoPort
                }
@@ -337,7 +338,7 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header,
                if trace != nil && trace.TLSHandshakeStart != nil {
                        trace.TLSHandshakeStart()
                }
-               err := doHandshake(ctx, tlsConn, cfg)
+               err := utils.DoHandshake(ctx, tlsConn, cfg)
                if trace != nil && trace.TLSHandshakeDone != nil {
                        trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
                }
@@ -388,9 +389,9 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header,
        }
 
        if resp.StatusCode != 101 ||
-               !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
-               !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
-               resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+               !utils.TokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+               !utils.TokenListContainsValue(resp.Header, "Connection", "upgrade") ||
+               resp.Header.Get("Sec-Websocket-Accept") != utils.ComputeAcceptKey(challengeKey) {
                // Before closing the network connection on return from this
                // function, slurp up some of the response to aid application
                // debugging.
@@ -415,51 +416,6 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
        return fn(network, addr)
 }
 
-//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
-func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
-
-//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
-func cloneTLSConfig(cfg *tls.Config) *tls.Config
-
-//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
-func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
-
-//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
-var errMalformedURL error
-
-//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
-var errInvalidCompression error
-
-//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
-func generateChallengeKey() (string, error)
-
-//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
-func tokenListContainsValue(header http.Header, name string, value string) bool
-
-//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
-// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
-
-//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
-func checkSameOrigin(r *http.Request) bool
-
-//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
-func isValidChallengeKey(s string) bool
-
-//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
-func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
-
-//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
-func parseExtensions(header http.Header) []map[string]string
-
-//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
-func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
-
-//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
-func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
-
-//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
-func computeAcceptKey(challengeKey string) string
-
 const (
        maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
        defaultReadBufferSize      = 4096
@@ -486,7 +442,7 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header)
                return nil, errors.New("websocket: client sent data before handshake is complete")
        }
 
-       buf := bufioWriterBuffer(netConn, brw.Writer)
+       buf := utils.BufioWriterBuffer(netConn, brw.Writer)
 
        var writeBuf []byte
        if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {