From d562536f2949781a62ceb2901a2537d43c224bca Mon Sep 17 00:00:00 2001 From: qydysky Date: Sun, 10 Dec 2023 03:54:44 +0800 Subject: [PATCH] 1 --- config.go | 71 +++++++++++++++++++++++------- main.go | 114 +++++++++++++++++++++++++++++++++++++------------ main/main.json | 3 +- 3 files changed, 142 insertions(+), 46 deletions(-) diff --git a/config.go b/config.go index 4da6a38..559c195 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net/http" "sync" "time" ) @@ -24,6 +25,7 @@ type Config struct { type Route struct { Path string `json:"path"` Sign string `json:"-"` + Splicing int `json:"splicing"` ErrRedirect bool `json:"errRedirect"` Back []Back `json:"back"` } @@ -44,14 +46,20 @@ func (t *Route) GenBack() []*Back { var backLink []*Back for i := 0; i < len(t.Back); i++ { back := &t.Back[i] + back.SwapSign() + if back.Weight == 0 { + continue + } tmpBack := Back{ - Name: back.Name, - To: back.To, - Weight: back.Weight, - ErrBanSec: back.ErrBanSec, - PathAdd: back.PathAdd, - ReqHeader: append([]Header{}, back.ReqHeader...), - ResHeader: append([]Header{}, back.ResHeader...), + Name: back.Name, + Sign: back.Sign, + To: back.To, + Weight: back.Weight, + ErrBanSec: back.ErrBanSec, + PathAdd: back.PathAdd, + MatchHeader: append([]Header{}, back.MatchHeader...), + ReqHeader: append([]Header{}, back.ReqHeader...), + ResHeader: append([]Header{}, back.ResHeader...), } for i := 1; i <= back.Weight; i++ { backLink = append(backLink, &tmpBack) @@ -60,16 +68,44 @@ func (t *Route) GenBack() []*Back { return backLink } +func GetBackByRequest(backs []*Back, r *http.Request) []*Back { + var backLink []*Back + for i := 0; i < len(backs); i++ { + matchs := len(backs[i].MatchHeader) - 1 + for ; matchs >= 0 && + r.Header.Get(backs[i].MatchHeader[matchs].Key) == backs[i].MatchHeader[matchs].Value; matchs -= 1 { + } + if matchs == -1 { + backLink = append(backLink, backs[i]) + } + } + return backLink +} + type Back struct { - lock sync.RWMutex - upT time.Time - Name string `json:"name"` - To string `json:"to"` - Weight int `json:"weight"` - ErrBanSec int `json:"errBanSec"` - PathAdd bool `json:"pathAdd"` - ReqHeader []Header `json:"reqHeader"` - ResHeader []Header `json:"resHeader"` + lock sync.RWMutex + Sign string `json:"-"` + upT time.Time + Name string `json:"name"` + To string `json:"to"` + Weight int `json:"weight"` + ErrBanSec int `json:"errBanSec"` + PathAdd bool `json:"pathAdd"` + MatchHeader []Header `json:"matchHeader"` + ReqHeader []Header `json:"reqHeader"` + ResHeader []Header `json:"resHeader"` +} + +func (t *Back) SwapSign() bool { + data, _ := json.Marshal(t) + w := md5.New() + w.Write(data) + sign := fmt.Sprintf("%x", w.Sum(nil)) + if t.Sign != sign { + t.Sign = sign + return true + } + return false } func (t *Back) IsLive() bool { @@ -79,6 +115,9 @@ func (t *Back) IsLive() bool { } func (t *Back) Disable() { + if t.ErrBanSec == 0 { + return + } t.lock.Lock() defer t.lock.Unlock() t.upT = time.Now().Add(time.Second * time.Duration(t.ErrBanSec)) diff --git a/main.go b/main.go index 4cd4daf..9209545 100644 --- a/main.go +++ b/main.go @@ -8,8 +8,8 @@ import ( "fmt" "io" "net/http" - "strings" "time" + _ "unsafe" "github.com/gorilla/websocket" pctx "github.com/qydysky/part/ctx" @@ -72,6 +72,8 @@ func Test(ctx context.Context, port int, logger Logger) { <-ctx1.Done() } +var cookie = fmt.Sprintf("%p", &struct{}{}) + // 转发 func Run(ctx context.Context, configSP *Config, logger Logger) { // 根ctx @@ -165,6 +167,7 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log for i := 0; i < len(configS.Routes); i++ { route := &configS.Routes[i] path := route.Path + splicing := route.Splicing ErrRedirect := route.ErrRedirect if !route.SwapSign() { @@ -185,47 +188,86 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log continue } + backMap := make(map[string]*Back) + + for i := 0; i < len(backArray); i++ { + backMap[backArray[i].Sign] = backArray[i] + } + logger.Info(`I:`, "路由更新", path) routeP.Store(path, func(w http.ResponseWriter, r *http.Request) { ctx1, done1 := pctx.WaitCtx(ctx) defer done1() - now := time.Now() - backI := now.UnixMilli() % int64(len(backArray)) - - if !backArray[backI].IsLive() { - for backI = 0; backI < int64(len(backArray)); backI++ { - if backArray[backI].IsLive() { - break + var backI *Back + validHost := validCookieDomain(r.Host) + if validHost { + if t, e := r.Cookie("_psign_" + cookie); e == nil { + if tmp, ok := backMap[t.Value]; ok { + backI = tmp } } - if backI == int64(len(backArray)) { + } + + if backI == nil { + backArray := GetBackByRequest(backArray, r) + switch len(backArray) { + case 0: w.WriteHeader(http.StatusServiceUnavailable) - logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path)) + logger.Error(`W:`, fmt.Sprintf("%s=> 无匹配", path)) return + case 1: + backI = backArray[0] + default: + backI = backArray[time.Now().UnixMilli()%int64(len(backArray))] + } + if !backI.IsLive() { + backI = nil + for i := 0; i < len(backArray); i++ { + if backArray[i].IsLive() { + backI = backArray[i] + break + } + } + if backI == nil { + w.WriteHeader(http.StatusServiceUnavailable) + logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path)) + return + } } } - logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backArray[backI].Name)) + if validHost { + w.Header().Add("Set-Cookie", (&http.Cookie{ + Name: "_psign_" + cookie, + Value: backI.Sign, + MaxAge: splicing, + Path: path, + }).String()) + } + + w.Header().Add("_pto_"+cookie, backI.Name) + + logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backI.Name)) var e error if r.Header.Get("Upgrade") == "websocket" { - e = wsDealer(ctx1, w, r, path, backArray[backI], logger) + e = wsDealer(ctx1, w, r, path, backI, logger) } else { - e = httpDealer(ctx1, w, r, path, backArray[backI], logger) + e = httpDealer(ctx1, w, r, path, backI, logger) } if e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backArray[backI].Name, e)) + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backI.Name, e)) switch e { case ErrCopy: - backArray[backI].Disable() + backI.Disable() return case ErrHeaderCheckFail: w.WriteHeader(http.StatusForbidden) return default: - backArray[backI].Disable() + backI.Disable() if ErrRedirect { w.Header().Set("Location", r.URL.String()) w.WriteHeader(http.StatusTemporaryRedirect) @@ -253,8 +295,14 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou url += r.URL.String() } - if !strings.HasPrefix(url, "http") { - return ErrNoHttp + url = "http" + url + + for _, v := range back.ReqHeader { + if v.Action == `check` { + if r.Header.Get(v.Key) != v.Value { + return ErrHeaderCheckFail + } + } } req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body) @@ -269,9 +317,6 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou for _, v := range back.ReqHeader { switch v.Action { case `check`: - if req.Header.Get(v.Key) != v.Value { - return ErrHeaderCheckFail - } case `set`: req.Header.Set(v.Key, v.Value) case `add`: @@ -288,8 +333,13 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou return ErrReqDoFail } + header := w.Header() for k, v := range resp.Header { - w.Header().Set(k, v[0]) + if has(&header, k) { + header.Add(k, v[0]) + } else { + header.Set(k, v[0]) + } } for _, v := range back.ResHeader { @@ -299,11 +349,11 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou return ErrHeaderCheckFail } case `set`: - w.Header().Set(v.Key, v.Value) + header.Set(v.Key, v.Value) case `add`: - w.Header().Add(v.Key, v.Value) + header.Add(v.Key, v.Value) case `del`: - w.Header().Del(v.Key) + header.Del(v.Key) default: logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v)) } @@ -329,9 +379,7 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route url += r.URL.String() } - if !strings.HasPrefix(url, "ws") { - return ErrNoWs - } + url = "ws" + url reqHeader := make(http.Header) for _, v := range back.ReqHeader { @@ -353,6 +401,10 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil { return ErrReqDoFail } else { + resp.Header.Del("connection") + resp.Header.Del("upgrade") + resp.Header.Del("sec-websocket-accept") + for _, v := range back.ResHeader { switch v.Action { case `check`: @@ -410,3 +462,9 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route } } } + +//go:linkname has net/http.(*Header).has +func has(h *http.Header, key string) bool + +//go:linkname validCookieDomain net/http.validCookieDomain +func validCookieDomain(v string) bool diff --git a/main/main.json b/main/main.json index 7bb8f27..fdfe02b 100644 --- a/main/main.json +++ b/main/main.json @@ -16,7 +16,6 @@ "to": "http://127.0.0.1:13000", "weight": 1, "pathAdd": false, - "errBanSec": 10, "resHeader": [ { "action": "set", @@ -33,7 +32,7 @@ "matchRule": "prefix", "routes": [ { - "path": "/2", + "matcher": "/2", "errRedirect": true, "back": [ { -- 2.39.2