]> 127.0.0.1 Git - front/.git/commitdiff
1 v0.1.20231209195503
authorqydysky <qydysky@foxmail.com>
Sat, 9 Dec 2023 19:54:44 +0000 (03:54 +0800)
committerqydysky <qydysky@foxmail.com>
Sat, 9 Dec 2023 19:54:44 +0000 (03:54 +0800)
config.go
main.go
main/main.json

index 4da6a3818dd73a68b1e2ffe1148b56ddd2babe7f..559c1957a9ce8b1d775f430a18eca450cb55fbf3 100644 (file)
--- a/config.go
+++ b/config.go
@@ -5,6 +5,7 @@ import (
        "crypto/tls"
        "encoding/json"
        "fmt"
+       "net/http"
        "sync"
        "time"
 )
@@ -24,6 +25,7 @@ type Config struct {
 type Route struct {
        Path        string `json:"path"`
        Sign        string `json:"-"`
+       Splicing    int    `json:"splicing"`
        ErrRedirect bool   `json:"errRedirect"`
        Back        []Back `json:"back"`
 }
@@ -44,14 +46,20 @@ func (t *Route) GenBack() []*Back {
        var backLink []*Back
        for i := 0; i < len(t.Back); i++ {
                back := &t.Back[i]
+               back.SwapSign()
+               if back.Weight == 0 {
+                       continue
+               }
                tmpBack := Back{
-                       Name:      back.Name,
-                       To:        back.To,
-                       Weight:    back.Weight,
-                       ErrBanSec: back.ErrBanSec,
-                       PathAdd:   back.PathAdd,
-                       ReqHeader: append([]Header{}, back.ReqHeader...),
-                       ResHeader: append([]Header{}, back.ResHeader...),
+                       Name:        back.Name,
+                       Sign:        back.Sign,
+                       To:          back.To,
+                       Weight:      back.Weight,
+                       ErrBanSec:   back.ErrBanSec,
+                       PathAdd:     back.PathAdd,
+                       MatchHeader: append([]Header{}, back.MatchHeader...),
+                       ReqHeader:   append([]Header{}, back.ReqHeader...),
+                       ResHeader:   append([]Header{}, back.ResHeader...),
                }
                for i := 1; i <= back.Weight; i++ {
                        backLink = append(backLink, &tmpBack)
@@ -60,16 +68,44 @@ func (t *Route) GenBack() []*Back {
        return backLink
 }
 
+func GetBackByRequest(backs []*Back, r *http.Request) []*Back {
+       var backLink []*Back
+       for i := 0; i < len(backs); i++ {
+               matchs := len(backs[i].MatchHeader) - 1
+               for ; matchs >= 0 &&
+                       r.Header.Get(backs[i].MatchHeader[matchs].Key) == backs[i].MatchHeader[matchs].Value; matchs -= 1 {
+               }
+               if matchs == -1 {
+                       backLink = append(backLink, backs[i])
+               }
+       }
+       return backLink
+}
+
 type Back struct {
-       lock      sync.RWMutex
-       upT       time.Time
-       Name      string   `json:"name"`
-       To        string   `json:"to"`
-       Weight    int      `json:"weight"`
-       ErrBanSec int      `json:"errBanSec"`
-       PathAdd   bool     `json:"pathAdd"`
-       ReqHeader []Header `json:"reqHeader"`
-       ResHeader []Header `json:"resHeader"`
+       lock        sync.RWMutex
+       Sign        string `json:"-"`
+       upT         time.Time
+       Name        string   `json:"name"`
+       To          string   `json:"to"`
+       Weight      int      `json:"weight"`
+       ErrBanSec   int      `json:"errBanSec"`
+       PathAdd     bool     `json:"pathAdd"`
+       MatchHeader []Header `json:"matchHeader"`
+       ReqHeader   []Header `json:"reqHeader"`
+       ResHeader   []Header `json:"resHeader"`
+}
+
+func (t *Back) SwapSign() bool {
+       data, _ := json.Marshal(t)
+       w := md5.New()
+       w.Write(data)
+       sign := fmt.Sprintf("%x", w.Sum(nil))
+       if t.Sign != sign {
+               t.Sign = sign
+               return true
+       }
+       return false
 }
 
 func (t *Back) IsLive() bool {
@@ -79,6 +115,9 @@ func (t *Back) IsLive() bool {
 }
 
 func (t *Back) Disable() {
+       if t.ErrBanSec == 0 {
+               return
+       }
        t.lock.Lock()
        defer t.lock.Unlock()
        t.upT = time.Now().Add(time.Second * time.Duration(t.ErrBanSec))
diff --git a/main.go b/main.go
index 4cd4daf7cd946c216819e0daa9b93ab45c098466..9209545c0ad52aefeb29ea4936fed242205c0832 100644 (file)
--- a/main.go
+++ b/main.go
@@ -8,8 +8,8 @@ import (
        "fmt"
        "io"
        "net/http"
-       "strings"
        "time"
+       _ "unsafe"
 
        "github.com/gorilla/websocket"
        pctx "github.com/qydysky/part/ctx"
@@ -72,6 +72,8 @@ func Test(ctx context.Context, port int, logger Logger) {
        <-ctx1.Done()
 }
 
+var cookie = fmt.Sprintf("%p", &struct{}{})
+
 // 转发
 func Run(ctx context.Context, configSP *Config, logger Logger) {
        // 根ctx
@@ -165,6 +167,7 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log
        for i := 0; i < len(configS.Routes); i++ {
                route := &configS.Routes[i]
                path := route.Path
+               splicing := route.Splicing
                ErrRedirect := route.ErrRedirect
 
                if !route.SwapSign() {
@@ -185,47 +188,86 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log
                        continue
                }
 
+               backMap := make(map[string]*Back)
+
+               for i := 0; i < len(backArray); i++ {
+                       backMap[backArray[i].Sign] = backArray[i]
+               }
+
                logger.Info(`I:`, "路由更新", path)
 
                routeP.Store(path, func(w http.ResponseWriter, r *http.Request) {
                        ctx1, done1 := pctx.WaitCtx(ctx)
                        defer done1()
 
-                       now := time.Now()
-                       backI := now.UnixMilli() % int64(len(backArray))
-
-                       if !backArray[backI].IsLive() {
-                               for backI = 0; backI < int64(len(backArray)); backI++ {
-                                       if backArray[backI].IsLive() {
-                                               break
+                       var backI *Back
+                       validHost := validCookieDomain(r.Host)
+                       if validHost {
+                               if t, e := r.Cookie("_psign_" + cookie); e == nil {
+                                       if tmp, ok := backMap[t.Value]; ok {
+                                               backI = tmp
                                        }
                                }
-                               if backI == int64(len(backArray)) {
+                       }
+
+                       if backI == nil {
+                               backArray := GetBackByRequest(backArray, r)
+                               switch len(backArray) {
+                               case 0:
                                        w.WriteHeader(http.StatusServiceUnavailable)
-                                       logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path))
+                                       logger.Error(`W:`, fmt.Sprintf("%s=> 无匹配", path))
                                        return
+                               case 1:
+                                       backI = backArray[0]
+                               default:
+                                       backI = backArray[time.Now().UnixMilli()%int64(len(backArray))]
+                               }
+                               if !backI.IsLive() {
+                                       backI = nil
+                                       for i := 0; i < len(backArray); i++ {
+                                               if backArray[i].IsLive() {
+                                                       backI = backArray[i]
+                                                       break
+                                               }
+                                       }
+                                       if backI == nil {
+                                               w.WriteHeader(http.StatusServiceUnavailable)
+                                               logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path))
+                                               return
+                                       }
                                }
                        }
 
-                       logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backArray[backI].Name))
+                       if validHost {
+                               w.Header().Add("Set-Cookie", (&http.Cookie{
+                                       Name:   "_psign_" + cookie,
+                                       Value:  backI.Sign,
+                                       MaxAge: splicing,
+                                       Path:   path,
+                               }).String())
+                       }
+
+                       w.Header().Add("_pto_"+cookie, backI.Name)
+
+                       logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backI.Name))
 
                        var e error
                        if r.Header.Get("Upgrade") == "websocket" {
-                               e = wsDealer(ctx1, w, r, path, backArray[backI], logger)
+                               e = wsDealer(ctx1, w, r, path, backI, logger)
                        } else {
-                               e = httpDealer(ctx1, w, r, path, backArray[backI], logger)
+                               e = httpDealer(ctx1, w, r, path, backI, logger)
                        }
                        if e != nil {
-                               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backArray[backI].Name, e))
+                               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backI.Name, e))
                                switch e {
                                case ErrCopy:
-                                       backArray[backI].Disable()
+                                       backI.Disable()
                                        return
                                case ErrHeaderCheckFail:
                                        w.WriteHeader(http.StatusForbidden)
                                        return
                                default:
-                                       backArray[backI].Disable()
+                                       backI.Disable()
                                        if ErrRedirect {
                                                w.Header().Set("Location", r.URL.String())
                                                w.WriteHeader(http.StatusTemporaryRedirect)
@@ -253,8 +295,14 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
                url += r.URL.String()
        }
 
-       if !strings.HasPrefix(url, "http") {
-               return ErrNoHttp
+       url = "http" + url
+
+       for _, v := range back.ReqHeader {
+               if v.Action == `check` {
+                       if r.Header.Get(v.Key) != v.Value {
+                               return ErrHeaderCheckFail
+                       }
+               }
        }
 
        req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
@@ -269,9 +317,6 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
        for _, v := range back.ReqHeader {
                switch v.Action {
                case `check`:
-                       if req.Header.Get(v.Key) != v.Value {
-                               return ErrHeaderCheckFail
-                       }
                case `set`:
                        req.Header.Set(v.Key, v.Value)
                case `add`:
@@ -288,8 +333,13 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
                return ErrReqDoFail
        }
 
+       header := w.Header()
        for k, v := range resp.Header {
-               w.Header().Set(k, v[0])
+               if has(&header, k) {
+                       header.Add(k, v[0])
+               } else {
+                       header.Set(k, v[0])
+               }
        }
 
        for _, v := range back.ResHeader {
@@ -299,11 +349,11 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
                                return ErrHeaderCheckFail
                        }
                case `set`:
-                       w.Header().Set(v.Key, v.Value)
+                       header.Set(v.Key, v.Value)
                case `add`:
-                       w.Header().Add(v.Key, v.Value)
+                       header.Add(v.Key, v.Value)
                case `del`:
-                       w.Header().Del(v.Key)
+                       header.Del(v.Key)
                default:
                        logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
                }
@@ -329,9 +379,7 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
                url += r.URL.String()
        }
 
-       if !strings.HasPrefix(url, "ws") {
-               return ErrNoWs
-       }
+       url = "ws" + url
 
        reqHeader := make(http.Header)
        for _, v := range back.ReqHeader {
@@ -353,6 +401,10 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
        if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil {
                return ErrReqDoFail
        } else {
+               resp.Header.Del("connection")
+               resp.Header.Del("upgrade")
+               resp.Header.Del("sec-websocket-accept")
+
                for _, v := range back.ResHeader {
                        switch v.Action {
                        case `check`:
@@ -410,3 +462,9 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
                }
        }
 }
+
+//go:linkname has net/http.(*Header).has
+func has(h *http.Header, key string) bool
+
+//go:linkname validCookieDomain net/http.validCookieDomain
+func validCookieDomain(v string) bool
index 7bb8f27e000d7aae237c2372cbcb0ca18fe7efa5..fdfe02be5eb447dfb8a256291c93c7f28e2d2df3 100644 (file)
@@ -16,7 +16,6 @@
             "to": "http://127.0.0.1:13000",
             "weight": 1,
             "pathAdd": false,
-            "errBanSec": 10,
             "resHeader": [
               {
                 "action": "set",
@@ -33,7 +32,7 @@
     "matchRule": "prefix",
     "routes": [
       {
-        "path": "/2",
+        "matcher": "/2",
         "errRedirect": true,
         "back": [
           {