From c93d322b16625f26d2ac7de0db06756d49653003 Mon Sep 17 00:00:00 2001 From: qydysky Date: Sat, 6 Jan 2024 00:51:23 +0800 Subject: [PATCH] 1 --- config.go | 15 +-- main.go | 286 ++++++++++++++++++++++++++++--------------------- main/build.sh | 5 + main/main.json | 50 ++------- 4 files changed, 185 insertions(+), 171 deletions(-) create mode 100755 main/build.sh diff --git a/config.go b/config.go index 64fb27b..9a591c3 100644 --- a/config.go +++ b/config.go @@ -27,11 +27,10 @@ type Config struct { } type Route struct { - Path string `json:"path"` - Sign string `json:"-"` - Splicing int `json:"splicing"` - ErrRedirect bool `json:"errRedirect"` - Back []Back `json:"back"` + Path string `json:"path"` + Sign string `json:"-"` + Splicing int `json:"splicing"` + Back []Back `json:"back"` } func (t *Route) SwapSign() bool { @@ -56,6 +55,7 @@ func (t *Route) GenBack() []*Back { } tmpBack := Back{ Name: back.Name, + Splicing: t.Splicing, Sign: back.Sign, To: back.To, Weight: back.Weight, @@ -72,14 +72,14 @@ func (t *Route) GenBack() []*Back { return backLink } -func GetBackByRequest(backs []*Back, r *http.Request) []*Back { +func FiliterBackByRequest(backs []*Back, r *http.Request) []*Back { var backLink []*Back for i := 0; i < len(backs); i++ { matchs := len(backs[i].MatchHeader) - 1 for ; matchs >= 0 && r.Header.Get(backs[i].MatchHeader[matchs].Key) == backs[i].MatchHeader[matchs].Value; matchs -= 1 { } - if matchs == -1 { + if matchs == -1 && backs[i].IsLive() { backLink = append(backLink, backs[i]) } } @@ -89,6 +89,7 @@ func GetBackByRequest(backs []*Back, r *http.Request) []*Back { type Back struct { lock sync.RWMutex Sign string `json:"-"` + Splicing int `json:"-"` upT time.Time Name string `json:"name"` To string `json:"to"` diff --git a/main.go b/main.go index e513cf1..57c56e5 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "net/http" "net/http/httptrace" "net/url" + "strings" "time" _ "unsafe" @@ -188,8 +189,6 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log for i := 0; i < len(configS.Routes); i++ { route := &configS.Routes[i] path := route.Path - splicing := route.Splicing - ErrRedirect := route.ErrRedirect if !route.SwapSign() { continue @@ -221,80 +220,34 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log ctx1, done1 := pctx.WaitCtx(ctx) defer done1() - var backI *Back - validHost := validCookieDomain(r.Host) - if validHost { + var backIs []*Back + if validCookieDomain(r.Host) { if t, e := r.Cookie("_psign_" + cookie); e == nil { if tmp, ok := backMap[t.Value]; ok { - backI = tmp + backIs = append(backIs, tmp) } } } - if backI == nil { - backArray := GetBackByRequest(backArray, r) - switch len(backArray) { - case 0: - w.WriteHeader(http.StatusServiceUnavailable) - logger.Error(`W:`, fmt.Sprintf("%s=> 无匹配", path)) - return - case 1: - backI = backArray[0] - default: - backI = backArray[nanotime1()%int64(len(backArray))] - } - if !backI.IsLive() { - backI = nil - for i := 0; i < len(backArray); i++ { - if backArray[i].IsLive() { - backI = backArray[i] - break - } - } - if backI == nil { - w.WriteHeader(http.StatusServiceUnavailable) - logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path)) - return - } - } + if len(backIs) == 0 { + backIs = append(backIs, FiliterBackByRequest(backArray, r)...) } - if validHost { - w.Header().Add("Set-Cookie", (&http.Cookie{ - Name: "_psign_" + cookie, - Value: backI.Sign, - MaxAge: splicing, - Path: path, - Domain: r.Host, - }).String()) + if len(backIs) == 0 { + w.WriteHeader(http.StatusServiceUnavailable) + logger.Error(`W:`, fmt.Sprintf("%s=> 无可用后端", path)) + return } - w.Header().Add("_pto_"+cookie, backI.Name) - - logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backI.Name)) - var e error if r.Header.Get("Upgrade") == "websocket" { - e = wsDealer(ctx1, w, r, path, backI, logger, configS.BlocksI) + e = wsDealer(ctx1, w, r, path, backIs, logger, configS.BlocksI) } else { - e = httpDealer(ctx1, w, r, path, backI, logger, configS.BlocksI) + e = httpDealer(ctx1, w, r, path, backIs, logger, configS.BlocksI) } - if e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backI.Name, e)) - switch e { - case ErrCopy: - backI.Disable() - return - case ErrHeaderCheckFail: - w.WriteHeader(http.StatusForbidden) - return - default: - backI.Disable() - if ErrRedirect { - w.Header().Set("Location", r.URL.String()) - w.WriteHeader(http.StatusTemporaryRedirect) - } - } + if errors.Is(e, ErrHeaderCheckFail) { + w.WriteHeader(http.StatusForbidden) + return } }) } @@ -311,40 +264,73 @@ var ( ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail") ) -func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger Logger, blocksi pslice.BlocksI[byte]) error { - url := back.To - if back.PathAdd { - url += r.URL.String() - } +func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error { + var ( + resp *http.Response + chosenBack *Back + ) - url = "http" + url + for 0 < len(backs) && resp == nil { + chosenBack = backs[0] + backs = backs[1:] - for _, v := range back.ReqHeader { - if v.Action == `check` { - if r.Header.Get(v.Key) != v.Value { - return ErrHeaderCheckFail + url := chosenBack.To + if chosenBack.PathAdd { + url += r.URL.String() + } + + url = "http" + url + + for _, v := range chosenBack.ReqHeader { + if v.Action == `check` { + if r.Header.Get(v.Key) != v.Value { + return ErrHeaderCheckFail + } } } - } - req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body) - if e != nil { - return errors.Join(ErrReqCreFail, e) + req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body) + if e != nil { + return errors.Join(ErrReqCreFail, e) + } + + if e := copyHeader(r.Header, req.Header, chosenBack.ReqHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) + return e + } + + req.Header.Del("Referer") + + client := http.Client{} + resp, e = client.Do(req) + if e != nil { + chosenBack.Disable() + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) + } } - if e := copyHeader(r.Header, req.Header, back.ReqHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) - return e + if 0 == len(backs) && resp == nil { + logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name)) + return errors.New("全部后端故障") + } else if resp == nil { + return errors.New("后端故障") } - client := http.Client{} - resp, e := client.Do(req) - if e != nil { - return errors.Join(ErrReqDoFail, e) + logger.Error(`T:`, fmt.Sprintf("%s=>%s", routePath, chosenBack.Name)) + + if validCookieDomain(r.Host) { + w.Header().Add("Set-Cookie", (&http.Cookie{ + Name: "_psign_" + cookie, + Value: chosenBack.Sign, + MaxAge: chosenBack.Splicing, + Domain: r.Host, + }).String()) } - if e := copyHeader(resp.Header, w.Header(), back.ResHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + w.Header().Add("_pto_"+cookie, chosenBack.Name) + + if e := copyHeader(resp.Header, w.Header(), chosenBack.ResHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) return e } @@ -356,65 +342,105 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou defer resp.Body.Close() if tmpbuf, put, e := blocksi.Get(); e != nil { - logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) + chosenBack.Disable() return errors.Join(ErrCopy, e) } else { defer put() if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil { - logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) + chosenBack.Disable() return errors.Join(ErrCopy, e) } } return nil } -func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger Logger, blocksi pslice.BlocksI[byte]) error { - url := back.To - if back.PathAdd { - url += r.URL.String() - } +func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error { + var ( + resp *http.Response + conn net.Conn + chosenBack *Back + ) - url = "ws" + url + for 0 < len(backs) && resp == nil { + chosenBack = backs[0] + backs = backs[1:] - reqHeader := make(http.Header) + url := chosenBack.To + if chosenBack.PathAdd { + url += r.URL.String() + } - if e := copyHeader(r.Header, reqHeader, back.ReqHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) - return e - } + url = "ws" + url - if res, resp, e := DialContext(ctx, url, reqHeader); e != nil { - return errors.Join(ErrReqDoFail, e) - } else { - defer res.Close() + reqHeader := make(http.Header) - resHeader := make(http.Header) - if e := copyHeader(resp.Header, resHeader, back.ResHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + if e := copyHeader(r.Header, reqHeader, chosenBack.ReqHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) return e } - if req, e := Upgrade(w, r, resHeader); e != nil { - return errors.Join(ErrResDoFail, e) - } else { - defer req.Close() + reqHeader.Del("Referer") - select { - case e := <-copyWsMsg(req, res, blocksi): - if e != nil { - logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, back.Name, e)) - return errors.Join(ErrCopy, e) - } - case e := <-copyWsMsg(res, req, blocksi): - if e != nil { - logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, back.Name, e)) - return errors.Join(ErrCopy, e) - } - case <-ctx.Done(): - } + var e error + conn, resp, e = DialContext(ctx, url, reqHeader) + if e != nil { + chosenBack.Disable() + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) + } + } + + if 0 == len(backs) && (resp == nil || conn == nil) { + logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name)) + return errors.New("全部后端故障") + } else if resp == nil || conn == nil { + return errors.New("后端故障") + } + + logger.Error(`T:`, fmt.Sprintf("%s=>%s", routePath, chosenBack.Name)) + + if validCookieDomain(r.Host) { + w.Header().Add("Set-Cookie", (&http.Cookie{ + Name: "_psign_" + cookie, + Value: chosenBack.Sign, + MaxAge: chosenBack.Splicing, + Domain: r.Host, + }).String()) + } + + w.Header().Add("_pto_"+cookie, chosenBack.Name) + + defer conn.Close() + + resHeader := make(http.Header) + if e := copyHeader(resp.Header, resHeader, chosenBack.ResHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) + return e + } + + if req, e := Upgrade(w, r, resHeader); e != nil { + return errors.Join(ErrResDoFail, e) + } else { + defer req.Close() - return nil + select { + case e := <-copyWsMsg(req, conn, blocksi): + if e != nil { + chosenBack.Disable() + logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, chosenBack.Name, e)) + return errors.Join(ErrCopy, e) + } + case e := <-copyWsMsg(conn, req, blocksi): + if e != nil { + chosenBack.Disable() + logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, chosenBack.Name, e)) + return errors.Join(ErrCopy, e) + } + case <-ctx.Done(): } + + return nil } } @@ -422,7 +448,18 @@ func copyHeader(s, t http.Header, app []Header) error { sm := (map[string][]string)(s) tm := (map[string][]string)(t) for k, v := range sm { - tm[k] = v + if strings.ToLower(k) == "set-cookie" { + cookies := strings.Split(v[0], ";") + for k, v := range cookies { + if strings.Contains(strings.ToLower(v), "domain=") { + cookies = append(cookies[:k], cookies[k+1:]...) + break + } + } + tm[k] = []string{strings.Join(cookies, ";")} + } else { + tm[k] = v + } } for _, v := range app { switch v.Action { @@ -666,6 +703,7 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) buf := make([]byte, 1024) n, _ := io.ReadFull(resp.Body, buf) resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + log.Default().Println(resp.StatusCode, resp.Header) return nil, resp, websocket.ErrBadHandshake } diff --git a/main/build.sh b/main/build.sh new file mode 100755 index 0000000..4201438 --- /dev/null +++ b/main/build.sh @@ -0,0 +1,5 @@ +#!/bin/sh +rm -rf front.* +CGO_ENABLED=0 go build -buildmode=exe -o front.run . +CGO_ENABLED=0 GOOS=windows go build -buildmode=exe -o front.exe . +echo ok \ No newline at end of file diff --git a/main/main.json b/main/main.json index 093a83d..c4ffeb8 100644 --- a/main/main.json +++ b/main/main.json @@ -1,54 +1,24 @@ [ { - "addr": "0.0.0.0:9009", - "tls": { - "pub": "cert.pem", - "key": "key.pem" - }, + "addr": "0.0.0.0:8081", "matchRule": "prefix", "copyBlocks": 100, "routes": [ { - "path": "/1", - "errRedirect": true, + "path": "/", + "errRedirect": false, "back": [ { - "name": "test", - "to": "://127.0.0.1:13000", + "name": "baidu1", + "to": "s://www.baidu.com", "weight": 1, - "pathAdd": false, - "resHeader": [ - { - "action": "set", - "key": "KEY", - "value": "asf" - } - ] - } - ] - } - ] - },{ - "addr": "0.0.0.0:9010", - "matchRule": "prefix", - "routes": [ - { - "matcher": "/2", - "errRedirect": true, - "back": [ + "pathAdd": true + }, { - "name": "test", - "to": "://127.0.0.1:13000", + "name": "baidu2", + "to": "s://www.baidu.com", "weight": 1, - "pathAdd": false, - "errBanSec": 10, - "resHeader": [ - { - "action": "set", - "key": "KEY", - "value": "asf" - } - ] + "pathAdd": true } ] } -- 2.39.2