]> 127.0.0.1 Git - front/.git/commitdiff
1 v0.1.20240309181159
authorqydysky <qydysky@foxmail.com>
Sat, 9 Mar 2024 18:11:40 +0000 (02:11 +0800)
committerqydysky <qydysky@foxmail.com>
Sat, 9 Mar 2024 18:11:40 +0000 (02:11 +0800)
README.md
config.go
go.mod
go.sum
http.go [new file with mode: 0644]
main.go
main/main.go
ws.go [new file with mode: 0644]

index 6d6ade19f694a907e89ffc3231f5971e17ccb5be..caa8acd4cac66eb622211f0cc670ed1a78b17c70 100755 (executable)
--- a/README.md
+++ b/README.md
@@ -36,21 +36,17 @@ config:
 
 matcher:
 
-- matchHeader: [] 匹配客户端请求头,只有都匹配才使用此后端, 可以动态增加/删除
-    - key: string 要匹配的header名
-    - matchExp: string 要匹配的正则式
-    - value: string 要匹配的值
 - reqHeader: [] 请求后端前,请求头处理器, 可以动态增加/删除
-    - action: string 可选check、replace、add、del、set。
+    - action: string 可选access、deny、replace、add、del、set。
     - key: string 具体处理哪个头
-    - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
-    - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+    - matchExp: string access时不匹配将结束请求。deny时匹配将结束请求。replace时结合value进行替换
+    - value: string replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
 - resHeader: [] 返回后端的响应前,请求头处理器, 可以动态增加/删除
-    - action: string 可选check、add、del、set。
+    - action: string 可选access、deny、add、del、set。
     - key: string 具体处理哪个头
-    - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
-    - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+    - matchExp: string access时不匹配将结束请求。deny时匹配将结束请求。replace时结合value进行替换
+    - value: string replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
 - reqBody: [] 请求后端前,请求数据过滤器, 可以动态增加/删除
-    - action: string 可选access,deny。
+    - action: string 可选accessdeny。
     - reqSize:string 限定请求数据大小,默认为"1M"
     - matchExp: string access时如不匹配将结束请求。deny时如匹配将结束请求。
index d295606cc7c03829a5542a027f2718625fb619f0..460538db4e4dbfe3f72f5509184ecf83c016d085 100755 (executable)
--- a/config.go
+++ b/config.go
@@ -8,6 +8,7 @@ import (
        "fmt"
        "io"
        "math/rand/v2"
+       "net"
        "net/http"
        "regexp"
        "sync"
@@ -31,6 +32,7 @@ type Config struct {
        CopyBlocks int                  `json:"copyBlocks"`
        BlocksI    pslice.BlocksI[byte] `json:"-"`
 
+       routeP   pweb.WebPath
        routeMap sync.Map `json:"-"`
        Routes   []Route  `json:"routes"`
 }
@@ -39,20 +41,21 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
        ctx, done := pctx.WithWait(ctx, 0, time.Minute)
        defer done()
 
-       routeP := pweb.WebPath{}
-
        var matchfunc func(path string) (func(w http.ResponseWriter, r *http.Request), bool)
        switch t.MatchRule {
        case "all":
-               matchfunc = routeP.Load
+               matchfunc = t.routeP.Load
        default:
-               matchfunc = routeP.LoadPerfix
+               matchfunc = t.routeP.LoadPerfix
        }
 
-       httpSer := http.Server{Addr: t.Addr}
+       httpSer := http.Server{
+               Addr:        t.Addr,
+               BaseContext: func(l net.Listener) context.Context { return ctx },
+       }
        if t.TLS.Key != "" && t.TLS.Pub != "" {
                if cert, e := tls.LoadX509KeyPair(t.TLS.Pub, t.TLS.Key); e != nil {
-                       logger.Error(`E:`, e)
+                       logger.Error(`E:`, fmt.Sprintf("%v %v", t.Addr, e))
                } else {
                        httpSer.TLSConfig = &tls.Config{
                                Certificates: []tls.Certificate{cert},
@@ -67,24 +70,29 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
                t.BlocksI = pslice.NewBlocks[byte](16*1024, t.CopyBlocks)
        }
 
-       syncWeb := pweb.NewSyncMap(&httpSer, &routeP, matchfunc)
+       syncWeb := pweb.NewSyncMap(&httpSer, &t.routeP, matchfunc)
        defer syncWeb.Shutdown()
 
-       var addRoute = func(k string, route *Route) {
-               logger.Info(`I:`, "路由加载", k)
-               t.routeMap.Store(k, route)
+       t.SwapSign(ctx, logger)
+       logger.Info(`I:`, fmt.Sprintf("%v running", t.Addr))
+       <-ctx.Done()
+       logger.Info(`I:`, fmt.Sprintf("%v shutdown", t.Addr))
+}
 
-               routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) {
-                       ctx1, done1 := pctx.WaitCtx(ctx)
-                       defer done1()
+func (t *Config) SwapSign(ctx context.Context, logger Logger) {
+       var add = func(k string, route *Route, logger Logger) {
+               route.config = t
+               logger.Info(`I:`, fmt.Sprintf("%v > %v", t.Addr, k))
+               t.routeMap.Store(k, route)
 
-                       if !HeaderMatchs(route.MatchHeader, r) {
+               t.routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) {
+                       if !HeaderMatchs(route.ReqHeader, r) {
                                w.WriteHeader(http.StatusNotFound)
                        }
 
                        var backIs []*Back
                        if t, e := r.Cookie("_psign_" + cookie); e == nil {
-                               if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).MatchHeader, r) {
+                               if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).ReqHeader, r) {
                                        backP.(*Back).PathAdd = route.PathAdd
                                        backP.(*Back).Splicing = route.Splicing
                                        backP.(*Back).tmp.ReqHeader = append(route.ReqHeader, backP.(*Back).ReqHeader...)
@@ -105,9 +113,12 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
 
                        var e error
                        if r.Header.Get("Upgrade") == "websocket" {
-                               e = wsDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI)
+                               e = wsDealer(r.Context(), w, r, route.Path, backIs, logger, t.BlocksI)
                        } else {
-                               e = httpDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI)
+                               e = httpDealer(r.Context(), w, r, route.Path, backIs, logger, t.BlocksI)
+                       }
+                       if e != nil {
+                               w.Header().Add(header+"Error", e.Error())
                        }
                        if errors.Is(e, ErrHeaderCheckFail) || errors.Is(e, ErrBodyCheckFail) {
                                w.WriteHeader(http.StatusForbidden)
@@ -116,25 +127,27 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
                })
        }
 
-       var delRoute = func(k string, route *Route) {
-               logger.Info(`I:`, "路由移除", k)
+       var del = func(k string, route *Route, logger Logger) {
+               logger.Info(`I:`, fmt.Sprintf("%v x %v", t.Addr, k))
                t.routeMap.Delete(k)
-               routeP.Store(k, nil)
+               t.routeP.Store(k, nil)
        }
 
-       t.SwapSign(addRoute, delRoute, logger)
-       logger.Info(`I:`, "启动完成")
-       for {
-               select {
-               case <-ctx.Done():
-                       return
-               case <-time.After(time.Second * 10):
-                       t.SwapSign(addRoute, delRoute, logger)
-               }
+       var routeU = func(route *Route, logger Logger) {
+               route.SwapSign(
+                       func(k string, b *Back) {
+                               b.route = route
+                               logger.Info(`I:`, fmt.Sprintf("%v > %v > %v", t.Addr, route.Path, b.Name))
+                               route.backMap.Store(k, b)
+                       },
+                       func(k string, b *Back) {
+                               logger.Info(`I:`, fmt.Sprintf("%v > %v x %v", t.Addr, route.Path, b.Name))
+                               route.backMap.Delete(k)
+                       },
+                       logger,
+               )
        }
-}
 
-func (t *Config) SwapSign(add func(string, *Route), del func(string, *Route), logger Logger) {
        t.routeMap.Range(func(key, value any) bool {
                var exist bool
                for k := 0; k < len(t.Routes); k++ {
@@ -144,34 +157,22 @@ func (t *Config) SwapSign(add func(string, *Route), del func(string, *Route), lo
                        }
                }
                if !exist {
-                       del(key.(string), value.(*Route))
+                       del(key.(string), value.(*Route), logger)
                }
                return true
        })
 
        for i := 0; i < len(t.Routes); i++ {
                if _, ok := t.routeMap.Load(t.Routes[i].Path); !ok {
-                       add(t.Routes[i].Path, &t.Routes[i])
+                       routeU(&t.Routes[i], logger)
+                       add(t.Routes[i].Path, &t.Routes[i], logger)
                }
        }
-
-       for i := 0; i < len(t.Routes); i++ {
-               t.Routes[i].SwapSign(
-                       func(k string, b *Back) {
-                               logger.Info(`I:`, "后端加载", t.Routes[i].Path, b.Name)
-                               t.Routes[i].backMap.Store(k, b)
-                       },
-                       func(k string, b *Back) {
-                               logger.Info(`I:`, "后端移除", t.Routes[i].Path, b.Name)
-                               t.Routes[i].backMap.Delete(k)
-                       },
-                       logger,
-               )
-       }
 }
 
 type Route struct {
-       Path string `json:"path"`
+       config *Config `json:"-"`
+       Path   string  `json:"path"`
 
        Splicing int  `json:"splicing"`
        PathAdd  bool `json:"pathAdd"`
@@ -237,7 +238,7 @@ func (t *Route) SwapSign(add func(string, *Back), del func(string, *Back), logge
 func (t *Route) FiliterBackByRequest(r *http.Request) []*Back {
        var backLink []*Back
        for i := 0; i < len(t.Backs); i++ {
-               if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].MatchHeader, r) {
+               if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].ReqHeader, r) {
                        t.Backs[i].PathAdd = t.PathAdd
                        t.Backs[i].Splicing = t.Splicing
                        t.Backs[i].tmp.ReqHeader = append(t.ReqHeader, t.Backs[i].ReqHeader...)
@@ -255,8 +256,9 @@ func (t *Route) FiliterBackByRequest(r *http.Request) []*Back {
 }
 
 type Back struct {
-       lock sync.RWMutex `json:"-"`
-       upT  time.Time    `json:"-"`
+       route *Route       `json:"-"`
+       lock  sync.RWMutex `json:"-"`
+       upT   time.Time    `json:"-"`
 
        Name      string `json:"name"`
        To        string `json:"to"`
@@ -310,10 +312,9 @@ func (t *Back) Disable() {
 }
 
 type Matcher struct {
-       MatchHeader []Header `json:"matchHeader"`
-       ReqHeader   []Header `json:"reqHeader"`
-       ResHeader   []Header `json:"resHeader"`
-       ReqBody     []Body   `json:"reqBody"`
+       ReqHeader []Header `json:"reqHeader"`
+       ResHeader []Header `json:"resHeader"`
+       ReqBody   []Body   `json:"reqBody"`
 }
 
 type Header struct {
@@ -324,15 +325,15 @@ type Header struct {
 }
 
 func (t *Header) Match(value string) bool {
-       if t.Value != "" && value != t.Value {
-               return false
+       if t.Action != "access" && t.Action != "deny" {
+               return true
        }
        if t.MatchExp != "" {
                if exp, e := regexp.Compile(t.MatchExp); e != nil || !exp.MatchString(value) {
-                       return false
+                       return t.Action == "deny"
                }
        }
-       return true
+       return t.Action == "access"
 }
 
 type Body struct {
diff --git a/go.mod b/go.mod
index 46c225815f5044dfc8d3505dfd33d7abee3a59b3..77a3d61384d86df64862d435fb89e1ccf935f839 100755 (executable)
--- a/go.mod
+++ b/go.mod
@@ -1,24 +1,30 @@
 module github.com/qydysky/front
 
-go 1.21
+go 1.22
 
 require (
+       github.com/dustin/go-humanize v1.0.1
        github.com/gorilla/websocket v1.5.1
-       github.com/qydysky/part v0.28.20240114140844
+       github.com/qydysky/part v0.28.20240309172046
        golang.org/x/net v0.18.0
 )
 
 require (
+       github.com/andybalholm/brotli v1.0.6 // indirect
        github.com/davecgh/go-spew v1.1.1 // indirect
-       github.com/dustin/go-humanize v1.0.1 // indirect
        github.com/go-ole/go-ole v1.3.0 // indirect
        github.com/google/uuid v1.4.0 // indirect
+       github.com/klauspost/compress v1.17.3 // indirect
+       github.com/miekg/dns v1.1.57 // indirect
        github.com/pmezard/go-difflib v1.0.0 // indirect
        github.com/shirou/gopsutil v3.21.11+incompatible // indirect
+       github.com/thedevsaddam/gojsonq/v2 v2.5.2 // indirect
        github.com/tklauser/go-sysconf v0.3.12 // indirect
        github.com/tklauser/numcpus v0.6.1 // indirect
        github.com/yusufpapurcu/wmi v1.2.3 // indirect
+       golang.org/x/mod v0.14.0 // indirect
        golang.org/x/sys v0.14.0 // indirect
        golang.org/x/text v0.14.0 // indirect
+       golang.org/x/tools v0.15.0 // indirect
        gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index 50a87c800c6aa32aa6f82466de5cf22371706212..1c27d4eecefcd3a6836f7c8ea2a7a17b0910adee 100755 (executable)
--- a/go.sum
+++ b/go.sum
@@ -13,20 +13,28 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/
 github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
+github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
 github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM=
+github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/qydysky/part v0.28.20240114140844 h1:9d0NPLx2CHKFca5Q6brECXeCjj0VXXaXmZRcJo6W85w=
-github.com/qydysky/part v0.28.20240114140844/go.mod h1:NyKyjpBCSjcHtKlC+fL5lCidm57UCnwEgufiBDs5yxA=
+github.com/qydysky/part v0.28.20240309114649 h1:b82WHpgNecfv/UX4makU9EIHJ4nLAP9dnXCb8S+sJt4=
+github.com/qydysky/part v0.28.20240309114649/go.mod h1:8Y4MrasGC0BLEM71QY/MuP2jl+v5b0Y+rqox3qJu97c=
+github.com/qydysky/part v0.28.20240309172046 h1:aw2Dv8VaP0p+IMkwJQlCNaz0ccJ6l8YUhu+y39kvQgU=
+github.com/qydysky/part v0.28.20240309172046/go.mod h1:8Y4MrasGC0BLEM71QY/MuP2jl+v5b0Y+rqox3qJu97c=
 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
 github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
 github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/thedevsaddam/gojsonq/v2 v2.5.2 h1:CoMVaYyKFsVj6TjU6APqAhAvC07hTI6IQen8PHzHYY0=
+github.com/thedevsaddam/gojsonq/v2 v2.5.2/go.mod h1:bv6Xa7kWy82uT0LnXPE2SzGqTj33TAEeR560MdJkiXs=
 github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
 github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
 github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
diff --git a/http.go b/http.go
new file mode 100644 (file)
index 0000000..51b1f58
--- /dev/null
+++ b/http.go
@@ -0,0 +1,111 @@
+package front
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "io"
+       "net/http"
+       "time"
+       _ "unsafe"
+
+       pslice "github.com/qydysky/part/slice"
+)
+
+func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
+
+       var (
+               opT        = time.Now()
+               resp       *http.Response
+               chosenBack *Back
+       )
+
+       for 0 < len(backs) && resp == nil {
+               chosenBack = backs[0]
+               backs = backs[1:]
+
+               url := chosenBack.To
+               if chosenBack.PathAdd {
+                       url += r.URL.String()
+               }
+
+               url = "http" + url
+
+               reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
+               if e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                       return errors.Join(ErrBodyCheckFail, e)
+               }
+
+               req, e := http.NewRequestWithContext(ctx, r.Method, url, reader)
+               if e != nil {
+                       return errors.Join(ErrReqCreFail, e)
+               }
+
+               if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                       return e
+               }
+
+               client := http.Client{
+                       CheckRedirect: func(req *http.Request, via []*http.Request) error {
+                               return ErrRedirect
+                       },
+               }
+               resp, e = client.Do(req)
+               if e != nil && !errors.Is(e, ErrRedirect) && !errors.Is(e, context.Canceled) {
+                       chosenBack.Disable()
+                       logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+               }
+       }
+
+       if 0 == len(backs) && resp == nil {
+               logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http 全部后端故障 %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+               return errors.New("全部后端故障")
+       } else if resp == nil {
+               return errors.New("后端故障")
+       }
+
+       logger.Debug(`T:`, fmt.Sprintf("%v > %v > %v http ok %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+
+       {
+               cookie := &http.Cookie{
+                       Name:   "_psign_" + cookie,
+                       Value:  chosenBack.Id(),
+                       MaxAge: chosenBack.Splicing,
+                       Path:   "/",
+               }
+               if validCookieDomain(r.Host) {
+                       cookie.Domain = r.Host
+               }
+               w.Header().Add("Set-Cookie", (cookie).String())
+       }
+
+       w.Header().Add(header+"Info", cookie+";"+chosenBack.Name)
+
+       if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil {
+               logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+               return e
+       }
+
+       w.WriteHeader(resp.StatusCode)
+
+       if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 {
+               return nil
+       }
+
+       defer resp.Body.Close()
+       if tmpbuf, put, e := blocksi.Get(); e != nil {
+               logger.Error(`E:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+               chosenBack.Disable()
+               return errors.Join(ErrCopy, e)
+       } else {
+               defer put()
+               if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil {
+                       logger.Error(`E:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                       chosenBack.Disable()
+                       return errors.Join(ErrCopy, e)
+               }
+       }
+       return nil
+}
diff --git a/main.go b/main.go
index fba1ca6a24abc61b8aaf1b9e2b63904d52a80ef6..5f0e1d00805d4d16c1282ce36f7cc4d0e491f07f 100755 (executable)
--- a/main.go
+++ b/main.go
@@ -1,29 +1,20 @@
 package front
 
 import (
-       "bufio"
-       "bytes"
        "context"
-       "crypto/tls"
+       "crypto/md5"
        "encoding/json"
        "errors"
        "fmt"
        "io"
-       "log"
-       "net"
        "net/http"
-       "net/http/httptrace"
-       "net/url"
        "regexp"
        "strings"
        "time"
        _ "unsafe"
 
-       "github.com/gorilla/websocket"
        pctx "github.com/qydysky/part/ctx"
-       pslice "github.com/qydysky/part/slice"
        pweb "github.com/qydysky/part/web"
-       "golang.org/x/net/proxy"
 )
 
 type Logger interface {
@@ -42,10 +33,17 @@ func validCookieDomain(v string) bool
 
 // 加载
 func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error {
-       if e := loadConfig(buf, configF, configS); e != nil {
+       var oldBufMd5 string
+
+       if bufMd5, e := loadConfig(ctx, buf, configF, configS, logger); e != nil {
                logger.Error(`E:`, "配置加载", e)
                return e
+       } else {
+               oldBufMd5 = bufMd5
        }
+
+       logger.Info(`I:`, "配置更新", oldBufMd5[:5])
+
        // 定时加载config
        go func() {
                ctx1, done1 := pctx.WaitCtx(ctx)
@@ -53,8 +51,11 @@ func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config
                for {
                        select {
                        case <-time.After(time.Second * 5):
-                               if e := loadConfig(buf, configF, configS); e != nil {
+                               if bufMd5, e := loadConfig(ctx, buf, configF, configS, logger); e != nil {
                                        logger.Error(`E:`, "配置加载", e)
+                               } else if bufMd5 != oldBufMd5 {
+                                       oldBufMd5 = bufMd5
+                                       logger.Info(`I:`, "配置更新", oldBufMd5[:5])
                                }
                        case <-ctx1.Done():
                                return
@@ -85,239 +86,32 @@ func Test(ctx context.Context, port int, logger Logger) {
        <-ctx1.Done()
 }
 
-var cookie = fmt.Sprintf("%p", &struct{}{})
-
-func loadConfig(buf []byte, configF File, configS *[]Config) error {
+func loadConfig(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) (md5k string, e error) {
+       defer func() {
+               if err := recover(); err != nil {
+                       logger.Error(`E:`, err)
+                       e = errors.New("read panic")
+               }
+       }()
        if i, e := configF.Read(buf); e != nil && !errors.Is(e, io.EOF) {
-               return e
+               return "", e
        } else if i == cap(buf) {
-               return errors.New(`buf full`)
+               return "", errors.New(`buf full`)
        } else {
-               for i := 0; i < len(*configS); i++ {
-                       (*configS)[i].lock.Lock()
-                       defer (*configS)[i].lock.Unlock()
-               }
-               if e := json.Unmarshal(buf[:i], configS); e != nil {
-                       return e
-               }
-       }
-       return nil
-}
+               w := md5.New()
+               w.Write(buf[:i])
+               md5k = fmt.Sprintf("%x", w.Sum(nil))
 
-//go:linkname nanotime1 runtime.nanotime1
-func nanotime1() int64
-
-var (
-       ErrRedirect        = errors.New("ErrRedirect")
-       ErrNoHttp          = errors.New("ErrNoHttp")
-       ErrNoWs            = errors.New("ErrNoWs")
-       ErrCopy            = errors.New("ErrCopy")
-       ErrReqCreFail      = errors.New("ErrReqCreFail")
-       ErrReqDoFail       = errors.New("ErrReqDoFail")
-       ErrResDoFail       = errors.New("ErrResDoFail")
-       ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail")
-       ErrBodyCheckFail   = errors.New("ErrBodyCheckFail")
-)
-
-func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
-       var (
-               opT        = time.Now()
-               resp       *http.Response
-               chosenBack *Back
-       )
-
-       for 0 < len(backs) && resp == nil {
-               chosenBack = backs[0]
-               backs = backs[1:]
-
-               url := chosenBack.To
-               if chosenBack.PathAdd {
-                       url += r.URL.String()
-               }
-
-               url = "http" + url
-
-               for _, v := range chosenBack.tmp.ReqHeader {
-                       if v.Action == `check` {
-                               if r.Header.Get(v.Key) != v.Value {
-                                       return ErrHeaderCheckFail
-                               }
-                       }
-               }
-
-               reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
-               if e != nil {
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail))
-                       return errors.Join(ErrBodyCheckFail, e)
-               }
-
-               req, e := http.NewRequestWithContext(ctx, r.Method, url, reader)
-               if e != nil {
-                       return errors.Join(ErrReqCreFail, e)
-               }
-
-               if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil {
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-                       return e
-               }
-
-               client := http.Client{
-                       CheckRedirect: func(req *http.Request, via []*http.Request) error {
-                               return ErrRedirect
-                       },
-               }
-               resp, e = client.Do(req)
-               if e != nil && !errors.Is(e, ErrRedirect) {
-                       chosenBack.Disable()
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-               }
-       }
-
-       if 0 == len(backs) && resp == nil {
-               logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name))
-               return errors.New("全部后端故障")
-       } else if resp == nil {
-               return errors.New("后端故障")
-       }
-
-       logger.Debug(`T:`, fmt.Sprintf("http %s=>%s %v", routePath, chosenBack.Name, time.Since(opT)))
-
-       {
-               cookie := &http.Cookie{
-                       Name:   "_psign_" + cookie,
-                       Value:  chosenBack.Id(),
-                       MaxAge: chosenBack.Splicing,
-                       Path:   "/",
-               }
-               if validCookieDomain(r.Host) {
-                       cookie.Domain = r.Host
-               }
-               w.Header().Add("Set-Cookie", (cookie).String())
-       }
-
-       w.Header().Add("_pto_"+cookie, chosenBack.Name)
-
-       if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil {
-               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-               return e
-       }
-
-       w.WriteHeader(resp.StatusCode)
-
-       if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 {
-               return nil
-       }
-
-       defer resp.Body.Close()
-       if tmpbuf, put, e := blocksi.Get(); e != nil {
-               logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-               chosenBack.Disable()
-               return errors.Join(ErrCopy, e)
-       } else {
-               defer put()
-               if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil {
-                       logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-                       chosenBack.Disable()
-                       return errors.Join(ErrCopy, e)
-               }
-       }
-       return nil
-}
-
-func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
-       var (
-               opT        = time.Now()
-               resp       *http.Response
-               conn       net.Conn
-               chosenBack *Back
-       )
-
-       for 0 < len(backs) && (resp == nil || conn == nil) {
-               chosenBack = backs[0]
-               backs = backs[1:]
-
-               _, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
-               if e != nil {
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail))
-                       return errors.Join(ErrBodyCheckFail, e)
-               }
-
-               url := chosenBack.To
-               if chosenBack.PathAdd {
-                       url += r.URL.String()
-               }
-
-               url = "ws" + url
-
-               reqHeader := make(http.Header)
-
-               if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil {
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-                       return e
-               }
-
-               conn, resp, e = DialContext(ctx, url, reqHeader)
-               if e != nil {
-                       chosenBack.Disable()
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-               }
-       }
-
-       if 0 == len(backs) && (resp == nil || conn == nil) {
-               logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name))
-               return errors.New("全部后端故障")
-       } else if resp == nil || conn == nil {
-               return errors.New("后端故障")
-       }
-
-       logger.Debug(`T:`, fmt.Sprintf("ws %s=>%s %v", routePath, chosenBack.Name, time.Since(opT)))
-
-       {
-               cookie := &http.Cookie{
-                       Name:   "_psign_" + cookie,
-                       Value:  chosenBack.Id(),
-                       MaxAge: chosenBack.Splicing,
-                       Path:   "/",
-               }
-               if validCookieDomain(r.Host) {
-                       cookie.Domain = r.Host
+               if e := json.Unmarshal(buf[:i], configS); e != nil {
+                       return md5k, e
                }
-               w.Header().Add("Set-Cookie", (cookie).String())
-       }
-
-       w.Header().Add("_pto_"+cookie, chosenBack.Name)
-
-       defer conn.Close()
-
-       resHeader := make(http.Header)
-       if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil {
-               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
-               return e
-       }
-
-       if req, e := Upgrade(w, r, resHeader); e != nil {
-               return errors.Join(ErrResDoFail, e)
-       } else {
-               defer req.Close()
-
-               select {
-               case e := <-copyWsMsg(req, conn, blocksi):
-                       if e != nil {
-                               chosenBack.Disable()
-                               logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, chosenBack.Name, e))
-                               return errors.Join(ErrCopy, e)
-                       }
-               case e := <-copyWsMsg(conn, req, blocksi):
-                       if e != nil {
-                               chosenBack.Disable()
-                               logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, chosenBack.Name, e))
-                               return errors.Join(ErrCopy, e)
-                       }
-               case <-ctx.Done():
+               for i := 0; i < len(*configS); i++ {
+                       (*configS)[i].lock.Lock()
+                       (*configS)[i].SwapSign(ctx, logger)
+                       (*configS)[i].lock.Unlock()
                }
-
-               return nil
        }
+       return md5k, nil
 }
 
 func copyHeader(s, t http.Header, app []Header) error {
@@ -339,8 +133,18 @@ func copyHeader(s, t http.Header, app []Header) error {
        }
        for _, v := range app {
                switch v.Action {
-               case `check`:
-                       if !v.Match(tm[v.Key][0]) {
+               case `deny`:
+                       if va, ok := tm[v.Key]; ok && len(va) != 0 {
+                               if exp, e := regexp.Compile(v.MatchExp); e == nil && exp.MatchString(va[0]) {
+                                       return ErrHeaderCheckFail
+                               }
+                       }
+               case `access`:
+                       if va, ok := tm[v.Key]; ok && len(va) != 0 {
+                               if exp, e := regexp.Compile(v.MatchExp); e != nil || !exp.MatchString(va[0]) {
+                                       return ErrHeaderCheckFail
+                               }
+                       } else {
                                return ErrHeaderCheckFail
                        }
                case `replace`:
@@ -359,409 +163,16 @@ func copyHeader(s, t http.Header, app []Header) error {
        return nil
 }
 
-func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-chan error {
-       c := make(chan error, 1)
-       go func() {
-               if tmpbuf, put, e := blocksi.Get(); e != nil {
-                       c <- e
-               } else {
-                       defer put()
-                       _, e := io.CopyBuffer(dst, src, tmpbuf)
-                       c <- e
-               }
-       }()
-       return c
-}
-
-func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) {
-       d := websocket.DefaultDialer
-
-       challengeKey := requestHeader.Get("Sec-WebSocket-Key")
-
-       u, err := url.Parse(urlStr)
-       if err != nil {
-               return nil, nil, err
-       }
-
-       switch u.Scheme {
-       case "ws":
-               u.Scheme = "http"
-       case "wss":
-               u.Scheme = "https"
-       default:
-               return nil, nil, errMalformedURL
-       }
-
-       if u.User != nil {
-               // User name and password are not allowed in websocket URIs.
-               return nil, nil, errMalformedURL
-       }
-
-       req := &http.Request{
-               Method:     http.MethodGet,
-               URL:        u,
-               Proto:      "HTTP/1.1",
-               ProtoMajor: 1,
-               ProtoMinor: 1,
-               Header:     make(http.Header),
-               Host:       u.Host,
-       }
-       req = req.WithContext(ctx)
-
-       // Set the request headers using the capitalization for names and values in
-       // RFC examples. Although the capitalization shouldn't matter, there are
-       // servers that depend on it. The Header.Set method is not used because the
-       // method canonicalizes the header names.
-       for k, vs := range requestHeader {
-               req.Header[k] = vs
-       }
-
-       if d.HandshakeTimeout != 0 {
-               var cancel func()
-               ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
-               defer cancel()
-       }
-
-       // Get network dial function.
-       var netDial func(network, add string) (net.Conn, error)
-
-       switch u.Scheme {
-       case "http":
-               if d.NetDialContext != nil {
-                       netDial = func(network, addr string) (net.Conn, error) {
-                               return d.NetDialContext(ctx, network, addr)
-                       }
-               } else if d.NetDial != nil {
-                       netDial = d.NetDial
-               }
-       case "https":
-               if d.NetDialTLSContext != nil {
-                       netDial = func(network, addr string) (net.Conn, error) {
-                               return d.NetDialTLSContext(ctx, network, addr)
-                       }
-               } else if d.NetDialContext != nil {
-                       netDial = func(network, addr string) (net.Conn, error) {
-                               return d.NetDialContext(ctx, network, addr)
-                       }
-               } else if d.NetDial != nil {
-                       netDial = d.NetDial
-               }
-       default:
-               return nil, nil, errMalformedURL
-       }
-
-       if netDial == nil {
-               netDialer := &net.Dialer{}
-               netDial = func(network, addr string) (net.Conn, error) {
-                       return netDialer.DialContext(ctx, network, addr)
-               }
-       }
-
-       // If needed, wrap the dial function to set the connection deadline.
-       if deadline, ok := ctx.Deadline(); ok {
-               forwardDial := netDial
-               netDial = func(network, addr string) (net.Conn, error) {
-                       c, err := forwardDial(network, addr)
-                       if err != nil {
-                               return nil, err
-                       }
-                       err = c.SetDeadline(deadline)
-                       if err != nil {
-                               if err := c.Close(); err != nil {
-                                       log.Printf("websocket: failed to close network connection: %v", err)
-                               }
-                               return nil, err
-                       }
-                       return c, nil
-               }
-       }
-
-       // If needed, wrap the dial function to connect through a proxy.
-       if d.Proxy != nil {
-               proxyURL, err := d.Proxy(req)
-               if err != nil {
-                       return nil, nil, err
-               }
-               if proxyURL != nil {
-                       dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
-                       if err != nil {
-                               return nil, nil, err
-                       }
-                       netDial = dialer.Dial
-               }
-       }
-
-       hostPort, hostNoPort := hostPortNoPort(u)
-       trace := httptrace.ContextClientTrace(ctx)
-       if trace != nil && trace.GetConn != nil {
-               trace.GetConn(hostPort)
-       }
-
-       netConn, err := netDial("tcp", hostPort)
-       if err != nil {
-               return nil, nil, err
-       }
-       if trace != nil && trace.GotConn != nil {
-               trace.GotConn(httptrace.GotConnInfo{
-                       Conn: netConn,
-               })
-       }
-
-       if u.Scheme == "https" && d.NetDialTLSContext == nil {
-               // If NetDialTLSContext is set, assume that the TLS handshake has already been done
-
-               cfg := cloneTLSConfig(d.TLSClientConfig)
-               if cfg.ServerName == "" {
-                       cfg.ServerName = hostNoPort
-               }
-               tlsConn := tls.Client(netConn, cfg)
-               netConn = tlsConn
-
-               if trace != nil && trace.TLSHandshakeStart != nil {
-                       trace.TLSHandshakeStart()
-               }
-               err := doHandshake(ctx, tlsConn, cfg)
-               if trace != nil && trace.TLSHandshakeDone != nil {
-                       trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
-               }
-
-               if err != nil {
-                       return nil, nil, err
-               }
-       }
-
-       var br *bufio.Reader
-       if br == nil {
-               if d.ReadBufferSize == 0 {
-                       d.ReadBufferSize = defaultReadBufferSize
-               } else if d.ReadBufferSize < maxControlFramePayloadSize {
-                       // must be large enough for control frame
-                       d.ReadBufferSize = maxControlFramePayloadSize
-               }
-               br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
-       }
-
-       if err := req.Write(netConn); err != nil {
-               return nil, nil, err
-       }
-
-       if trace != nil && trace.GotFirstResponseByte != nil {
-               if peek, err := br.Peek(1); err == nil && len(peek) == 1 {
-                       trace.GotFirstResponseByte()
-               }
-       }
-
-       resp, err := http.ReadResponse(br, req)
-       if err != nil {
-               if d.TLSClientConfig != nil {
-                       for _, proto := range d.TLSClientConfig.NextProtos {
-                               if proto != "http/1.1" {
-                                       return nil, nil, fmt.Errorf(
-                                               "websocket: protocol %q was given but is not supported;"+
-                                                       "sharing tls.Config with net/http Transport can cause this error: %w",
-                                               proto, err,
-                                       )
-                               }
-                       }
-               }
-               return nil, nil, err
-       }
-
-       if d.Jar != nil {
-               if rc := resp.Cookies(); len(rc) > 0 {
-                       d.Jar.SetCookies(u, rc)
-               }
-       }
-
-       if resp.StatusCode != 101 ||
-               !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
-               !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
-               resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
-               // Before closing the network connection on return from this
-               // function, slurp up some of the response to aid application
-               // debugging.
-               buf := make([]byte, 1024)
-               n, _ := io.ReadFull(resp.Body, buf)
-               resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
-               log.Default().Println(resp.StatusCode, resp.Header)
-               return nil, resp, websocket.ErrBadHandshake
-       }
-
-       resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
-
-       if err := netConn.SetDeadline(time.Time{}); err != nil {
-               return nil, nil, err
-       }
-       return netConn, resp, nil
-}
-
-type netDialerFunc func(network, addr string) (net.Conn, error)
-
-func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
-       return fn(network, addr)
-}
-
-//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
-func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
-
-//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
-func cloneTLSConfig(cfg *tls.Config) *tls.Config
-
-//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
-func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
-
-//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
-var errMalformedURL error
-
-//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
-var errInvalidCompression error
-
-//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
-func generateChallengeKey() (string, error)
-
-//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
-func tokenListContainsValue(header http.Header, name string, value string) bool
-
-//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
-// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
-
-//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
-func checkSameOrigin(r *http.Request) bool
-
-//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
-func isValidChallengeKey(s string) bool
-
-//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
-func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
-
-//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
-func parseExtensions(header http.Header) []map[string]string
-
-//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
-func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
-
-//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
-func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
-
-//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
-func computeAcceptKey(challengeKey string) string
-
-const (
-       maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
-       defaultReadBufferSize      = 4096
-       defaultWriteBufferSize     = 4096
-       maxControlFramePayloadSize = 125
+var cookie = fmt.Sprintf("%p", &struct{}{})
+var header = "X-Front-"
+var (
+       ErrRedirect        = errors.New("ErrRedirect")
+       ErrNoHttp          = errors.New("ErrNoHttp")
+       ErrNoWs            = errors.New("ErrNoWs")
+       ErrCopy            = errors.New("ErrCopy")
+       ErrReqCreFail      = errors.New("ErrReqCreFail")
+       ErrReqDoFail       = errors.New("ErrReqDoFail")
+       ErrResDoFail       = errors.New("ErrResDoFail")
+       ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail")
+       ErrBodyCheckFail   = errors.New("ErrBodyCheckFail")
 )
-
-func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) {
-       u := &websocket.Upgrader{}
-       h, ok := w.(http.Hijacker)
-       if !ok {
-               return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
-       }
-       var brw *bufio.ReadWriter
-       netConn, brw, err := h.Hijack()
-       if err != nil {
-               return returnError(u, w, r, http.StatusInternalServerError, err.Error())
-       }
-
-       if brw.Reader.Buffered() > 0 {
-               if err := netConn.Close(); err != nil {
-                       log.Printf("websocket: failed to close network connection: %v", err)
-               }
-               return nil, errors.New("websocket: client sent data before handshake is complete")
-       }
-
-       buf := bufioWriterBuffer(netConn, brw.Writer)
-
-       var writeBuf []byte
-       if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
-               // Reuse hijacked write buffer as connection buffer.
-               writeBuf = buf
-       } else {
-               if u.WriteBufferSize <= 0 {
-                       u.WriteBufferSize = defaultWriteBufferSize
-               }
-               u.WriteBufferSize += maxFrameHeaderSize
-               if writeBuf == nil && u.WriteBufferPool == nil {
-                       writeBuf = make([]byte, u.WriteBufferSize)
-               }
-       }
-
-       // Use larger of hijacked buffer and connection write buffer for header.
-       p := buf
-       if len(writeBuf) > len(p) {
-               p = writeBuf
-       }
-       p = p[:0]
-
-       p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
-       for k, vs := range responseHeader {
-               for _, v := range vs {
-                       p = append(p, k...)
-                       p = append(p, ": "...)
-                       for i := 0; i < len(v); i++ {
-                               b := v[i]
-                               if b <= 31 {
-                                       // prevent response splitting.
-                                       b = ' '
-                               }
-                               p = append(p, b)
-                       }
-                       p = append(p, "\r\n"...)
-               }
-       }
-       p = append(p, "\r\n"...)
-
-       // Clear deadlines set by HTTP server.
-       if err := netConn.SetDeadline(time.Time{}); err != nil {
-               if err := netConn.Close(); err != nil {
-                       log.Printf("websocket: failed to close network connection: %v", err)
-               }
-               return nil, err
-       }
-
-       if u.HandshakeTimeout > 0 {
-               if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
-                       if err := netConn.Close(); err != nil {
-                               log.Printf("websocket: failed to close network connection: %v", err)
-                       }
-                       return nil, err
-               }
-       }
-       if _, err = netConn.Write(p); err != nil {
-               if err := netConn.Close(); err != nil {
-                       log.Printf("websocket: failed to close network connection: %v", err)
-               }
-               return nil, err
-       }
-       if u.HandshakeTimeout > 0 {
-               if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
-                       if err := netConn.Close(); err != nil {
-                               log.Printf("websocket: failed to close network connection: %v", err)
-                       }
-                       return nil, err
-               }
-       }
-
-       return netConn, nil
-}
-
-func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) {
-       err := HandshakeError{message: reason}
-       if u.Error != nil {
-               u.Error(w, r, status, err)
-       } else {
-               w.Header().Set("Sec-Websocket-Version", "13")
-               http.Error(w, http.StatusText(status), status)
-       }
-       return nil, err
-}
-
-type HandshakeError struct {
-       message string
-}
-
-func (t HandshakeError) Error() string {
-       return t.message
-}
index 5508203d066ce3f71a4e7dc7396ad3dd5f4f6927..bcea7593a5a023386d3a0acbb2da89e5eb53684e 100755 (executable)
@@ -74,7 +74,7 @@ func main() {
        go pfront.Test(ctx, *testP, logger.Base("测试"))
 
        for i := 0; i < len(configS); i++ {
-               go configS[i].Run(ctx, logger.Base(configS[i].Addr))
+               go configS[i].Run(ctx, logger)
        }
 
        // ctrl+c退出
diff --git a/ws.go b/ws.go
new file mode 100644 (file)
index 0000000..061269c
--- /dev/null
+++ b/ws.go
@@ -0,0 +1,525 @@
+package front
+
+import (
+       "bufio"
+       "bytes"
+       "context"
+       "crypto/tls"
+       "errors"
+       "fmt"
+       "io"
+       "log"
+       "net"
+       "net/http"
+       "net/http/httptrace"
+       "net/url"
+       "time"
+       _ "unsafe"
+
+       "github.com/gorilla/websocket"
+       pslice "github.com/qydysky/part/slice"
+       "golang.org/x/net/proxy"
+)
+
+func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
+       var (
+               opT        = time.Now()
+               resp       *http.Response
+               conn       net.Conn
+               chosenBack *Back
+       )
+
+       for 0 < len(backs) && (resp == nil || conn == nil) {
+               chosenBack = backs[0]
+               backs = backs[1:]
+
+               _, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
+               if e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                       return errors.Join(ErrBodyCheckFail, e)
+               }
+
+               url := chosenBack.To
+               if chosenBack.PathAdd {
+                       url += r.URL.String()
+               }
+
+               url = "ws" + url
+
+               reqHeader := make(http.Header)
+
+               if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                       return e
+               }
+
+               conn, resp, e = DialContext(ctx, url, reqHeader)
+               if e != nil && !errors.Is(e, context.Canceled) {
+                       chosenBack.Disable()
+                       logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+               }
+       }
+
+       if 0 == len(backs) && (resp == nil || conn == nil) {
+               logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws 全部后端故障 %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+               return errors.New("全部后端故障")
+       } else if resp == nil || conn == nil {
+               return errors.New("后端故障")
+       }
+
+       logger.Debug(`T:`, fmt.Sprintf("%v > %v > %v ws ok %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+
+       {
+               cookie := &http.Cookie{
+                       Name:   "_psign_" + cookie,
+                       Value:  chosenBack.Id(),
+                       MaxAge: chosenBack.Splicing,
+                       Path:   "/",
+               }
+               if validCookieDomain(r.Host) {
+                       cookie.Domain = r.Host
+               }
+               w.Header().Add("Set-Cookie", (cookie).String())
+       }
+
+       w.Header().Add(header+"Info", cookie+";"+chosenBack.Name)
+
+       defer conn.Close()
+
+       resHeader := make(http.Header)
+       if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil {
+               logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+               return e
+       }
+
+       if req, e := Upgrade(w, r, resHeader); e != nil {
+               return errors.Join(ErrResDoFail, e)
+       } else {
+               defer req.Close()
+
+               select {
+               case e := <-copyWsMsg(req, conn, blocksi):
+                       if e != nil {
+                               chosenBack.Disable()
+                               logger.Error(`E:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                               return errors.Join(ErrCopy, e)
+                       }
+               case e := <-copyWsMsg(conn, req, blocksi):
+                       if e != nil {
+                               chosenBack.Disable()
+                               logger.Error(`E:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+                               return errors.Join(ErrCopy, e)
+                       }
+               case <-ctx.Done():
+               }
+
+               return nil
+       }
+}
+
+func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-chan error {
+       c := make(chan error, 1)
+       go func() {
+               if tmpbuf, put, e := blocksi.Get(); e != nil {
+                       c <- e
+               } else {
+                       defer put()
+                       _, e := io.CopyBuffer(dst, src, tmpbuf)
+                       c <- e
+               }
+       }()
+       return c
+}
+
+func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) {
+       d := websocket.DefaultDialer
+
+       challengeKey := requestHeader.Get("Sec-WebSocket-Key")
+
+       u, err := url.Parse(urlStr)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       switch u.Scheme {
+       case "ws":
+               u.Scheme = "http"
+       case "wss":
+               u.Scheme = "https"
+       default:
+               return nil, nil, errMalformedURL
+       }
+
+       if u.User != nil {
+               // User name and password are not allowed in websocket URIs.
+               return nil, nil, errMalformedURL
+       }
+
+       req := &http.Request{
+               Method:     http.MethodGet,
+               URL:        u,
+               Proto:      "HTTP/1.1",
+               ProtoMajor: 1,
+               ProtoMinor: 1,
+               Header:     make(http.Header),
+               Host:       u.Host,
+       }
+       req = req.WithContext(ctx)
+
+       // Set the request headers using the capitalization for names and values in
+       // RFC examples. Although the capitalization shouldn't matter, there are
+       // servers that depend on it. The Header.Set method is not used because the
+       // method canonicalizes the header names.
+       for k, vs := range requestHeader {
+               req.Header[k] = vs
+       }
+
+       if d.HandshakeTimeout != 0 {
+               var cancel func()
+               ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+               defer cancel()
+       }
+
+       // Get network dial function.
+       var netDial func(network, add string) (net.Conn, error)
+
+       switch u.Scheme {
+       case "http":
+               if d.NetDialContext != nil {
+                       netDial = func(network, addr string) (net.Conn, error) {
+                               return d.NetDialContext(ctx, network, addr)
+                       }
+               } else if d.NetDial != nil {
+                       netDial = d.NetDial
+               }
+       case "https":
+               if d.NetDialTLSContext != nil {
+                       netDial = func(network, addr string) (net.Conn, error) {
+                               return d.NetDialTLSContext(ctx, network, addr)
+                       }
+               } else if d.NetDialContext != nil {
+                       netDial = func(network, addr string) (net.Conn, error) {
+                               return d.NetDialContext(ctx, network, addr)
+                       }
+               } else if d.NetDial != nil {
+                       netDial = d.NetDial
+               }
+       default:
+               return nil, nil, errMalformedURL
+       }
+
+       if netDial == nil {
+               netDialer := &net.Dialer{}
+               netDial = func(network, addr string) (net.Conn, error) {
+                       return netDialer.DialContext(ctx, network, addr)
+               }
+       }
+
+       // If needed, wrap the dial function to set the connection deadline.
+       if deadline, ok := ctx.Deadline(); ok {
+               forwardDial := netDial
+               netDial = func(network, addr string) (net.Conn, error) {
+                       c, err := forwardDial(network, addr)
+                       if err != nil {
+                               return nil, err
+                       }
+                       err = c.SetDeadline(deadline)
+                       if err != nil {
+                               if err := c.Close(); err != nil {
+                                       log.Printf("websocket: failed to close network connection: %v", err)
+                               }
+                               return nil, err
+                       }
+                       return c, nil
+               }
+       }
+
+       // If needed, wrap the dial function to connect through a proxy.
+       if d.Proxy != nil {
+               proxyURL, err := d.Proxy(req)
+               if err != nil {
+                       return nil, nil, err
+               }
+               if proxyURL != nil {
+                       dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
+                       if err != nil {
+                               return nil, nil, err
+                       }
+                       netDial = dialer.Dial
+               }
+       }
+
+       hostPort, hostNoPort := hostPortNoPort(u)
+       trace := httptrace.ContextClientTrace(ctx)
+       if trace != nil && trace.GetConn != nil {
+               trace.GetConn(hostPort)
+       }
+
+       netConn, err := netDial("tcp", hostPort)
+       if err != nil {
+               return nil, nil, err
+       }
+       if trace != nil && trace.GotConn != nil {
+               trace.GotConn(httptrace.GotConnInfo{
+                       Conn: netConn,
+               })
+       }
+
+       if u.Scheme == "https" && d.NetDialTLSContext == nil {
+               // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
+               cfg := cloneTLSConfig(d.TLSClientConfig)
+               if cfg.ServerName == "" {
+                       cfg.ServerName = hostNoPort
+               }
+               tlsConn := tls.Client(netConn, cfg)
+               netConn = tlsConn
+
+               if trace != nil && trace.TLSHandshakeStart != nil {
+                       trace.TLSHandshakeStart()
+               }
+               err := doHandshake(ctx, tlsConn, cfg)
+               if trace != nil && trace.TLSHandshakeDone != nil {
+                       trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
+               }
+
+               if err != nil {
+                       return nil, nil, err
+               }
+       }
+
+       var br *bufio.Reader
+       if br == nil {
+               if d.ReadBufferSize == 0 {
+                       d.ReadBufferSize = defaultReadBufferSize
+               } else if d.ReadBufferSize < maxControlFramePayloadSize {
+                       // must be large enough for control frame
+                       d.ReadBufferSize = maxControlFramePayloadSize
+               }
+               br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
+       }
+
+       if err := req.Write(netConn); err != nil {
+               return nil, nil, err
+       }
+
+       if trace != nil && trace.GotFirstResponseByte != nil {
+               if peek, err := br.Peek(1); err == nil && len(peek) == 1 {
+                       trace.GotFirstResponseByte()
+               }
+       }
+
+       resp, err := http.ReadResponse(br, req)
+       if err != nil {
+               if d.TLSClientConfig != nil {
+                       for _, proto := range d.TLSClientConfig.NextProtos {
+                               if proto != "http/1.1" {
+                                       return nil, nil, fmt.Errorf(
+                                               "websocket: protocol %q was given but is not supported;"+
+                                                       "sharing tls.Config with net/http Transport can cause this error: %w",
+                                               proto, err,
+                                       )
+                               }
+                       }
+               }
+               return nil, nil, err
+       }
+
+       if d.Jar != nil {
+               if rc := resp.Cookies(); len(rc) > 0 {
+                       d.Jar.SetCookies(u, rc)
+               }
+       }
+
+       if resp.StatusCode != 101 ||
+               !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+               !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
+               resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+               // Before closing the network connection on return from this
+               // function, slurp up some of the response to aid application
+               // debugging.
+               buf := make([]byte, 1024)
+               n, _ := io.ReadFull(resp.Body, buf)
+               resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
+               log.Default().Println(resp.StatusCode, resp.Header)
+               return nil, resp, websocket.ErrBadHandshake
+       }
+
+       resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
+
+       if err := netConn.SetDeadline(time.Time{}); err != nil {
+               return nil, nil, err
+       }
+       return netConn, resp, nil
+}
+
+type netDialerFunc func(network, addr string) (net.Conn, error)
+
+func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
+       return fn(network, addr)
+}
+
+//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
+func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
+
+//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
+func cloneTLSConfig(cfg *tls.Config) *tls.Config
+
+//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
+
+//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
+var errMalformedURL error
+
+//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
+var errInvalidCompression error
+
+//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
+func generateChallengeKey() (string, error)
+
+//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
+func tokenListContainsValue(header http.Header, name string, value string) bool
+
+//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
+// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
+
+//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
+func checkSameOrigin(r *http.Request) bool
+
+//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
+func isValidChallengeKey(s string) bool
+
+//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
+func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
+
+//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
+func parseExtensions(header http.Header) []map[string]string
+
+//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
+func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
+
+//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
+func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
+
+//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
+func computeAcceptKey(challengeKey string) string
+
+const (
+       maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
+       defaultReadBufferSize      = 4096
+       defaultWriteBufferSize     = 4096
+       maxControlFramePayloadSize = 125
+)
+
+func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) {
+       u := &websocket.Upgrader{}
+       h, ok := w.(http.Hijacker)
+       if !ok {
+               return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
+       }
+       var brw *bufio.ReadWriter
+       netConn, brw, err := h.Hijack()
+       if err != nil {
+               return returnError(u, w, r, http.StatusInternalServerError, err.Error())
+       }
+
+       if brw.Reader.Buffered() > 0 {
+               if err := netConn.Close(); err != nil {
+                       log.Printf("websocket: failed to close network connection: %v", err)
+               }
+               return nil, errors.New("websocket: client sent data before handshake is complete")
+       }
+
+       buf := bufioWriterBuffer(netConn, brw.Writer)
+
+       var writeBuf []byte
+       if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
+               // Reuse hijacked write buffer as connection buffer.
+               writeBuf = buf
+       } else {
+               if u.WriteBufferSize <= 0 {
+                       u.WriteBufferSize = defaultWriteBufferSize
+               }
+               u.WriteBufferSize += maxFrameHeaderSize
+               if writeBuf == nil && u.WriteBufferPool == nil {
+                       writeBuf = make([]byte, u.WriteBufferSize)
+               }
+       }
+
+       // Use larger of hijacked buffer and connection write buffer for header.
+       p := buf
+       if len(writeBuf) > len(p) {
+               p = writeBuf
+       }
+       p = p[:0]
+
+       p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
+       for k, vs := range responseHeader {
+               for _, v := range vs {
+                       p = append(p, k...)
+                       p = append(p, ": "...)
+                       for i := 0; i < len(v); i++ {
+                               b := v[i]
+                               if b <= 31 {
+                                       // prevent response splitting.
+                                       b = ' '
+                               }
+                               p = append(p, b)
+                       }
+                       p = append(p, "\r\n"...)
+               }
+       }
+       p = append(p, "\r\n"...)
+
+       // Clear deadlines set by HTTP server.
+       if err := netConn.SetDeadline(time.Time{}); err != nil {
+               if err := netConn.Close(); err != nil {
+                       log.Printf("websocket: failed to close network connection: %v", err)
+               }
+               return nil, err
+       }
+
+       if u.HandshakeTimeout > 0 {
+               if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
+                       if err := netConn.Close(); err != nil {
+                               log.Printf("websocket: failed to close network connection: %v", err)
+                       }
+                       return nil, err
+               }
+       }
+       if _, err = netConn.Write(p); err != nil {
+               if err := netConn.Close(); err != nil {
+                       log.Printf("websocket: failed to close network connection: %v", err)
+               }
+               return nil, err
+       }
+       if u.HandshakeTimeout > 0 {
+               if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
+                       if err := netConn.Close(); err != nil {
+                               log.Printf("websocket: failed to close network connection: %v", err)
+                       }
+                       return nil, err
+               }
+       }
+
+       return netConn, nil
+}
+
+func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) {
+       err := HandshakeError{message: reason}
+       if u.Error != nil {
+               u.Error(w, r, status, err)
+       } else {
+               w.Header().Set("Sec-Websocket-Version", "13")
+               http.Error(w, http.StatusText(status), status)
+       }
+       return nil, err
+}
+
+type HandshakeError struct {
+       message string
+}
+
+func (t HandshakeError) Error() string {
+       return t.message
+}