From b1f00973fbd92055fae005f9008c232c75cacbef Mon Sep 17 00:00:00 2001 From: qydysky Date: Sun, 3 Dec 2023 02:59:00 +0800 Subject: [PATCH] 1 --- config.go | 4 ++ go.mod | 1 + go.sum | 2 + main.go | 205 ++++++++++++++++++++++++++++++++++++++++-------------- main.json | 2 + 5 files changed, 161 insertions(+), 53 deletions(-) diff --git a/config.go b/config.go index 6da6af6..7edf81a 100644 --- a/config.go +++ b/config.go @@ -36,8 +36,10 @@ func (t *Route) GenBack() []*Back { var backLink []*Back for _, back := range t.Back { tmpBack := Back{ + Name: back.Name, To: back.To, Weight: back.Weight, + PathAdd: back.PathAdd, ReqHeader: append([]Header{}, back.ReqHeader...), ResHeader: append([]Header{}, back.ResHeader...), } @@ -49,8 +51,10 @@ func (t *Route) GenBack() []*Back { } type Back struct { + Name string `json:"name"` To string `json:"to"` Weight int `json:"weight"` + PathAdd bool `json:"pathAdd"` ReqHeader []Header `json:"reqHeader"` ResHeader []Header `json:"resHeader"` } diff --git a/go.mod b/go.mod index 0db8234..9fd2063 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/andybalholm/brotli v1.0.6 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect + github.com/gorilla/websocket v1.5.1 github.com/klauspost/compress v1.17.3 // indirect github.com/miekg/dns v1.1.57 // indirect github.com/qydysky/part v0.28.20231202144738 // indirect diff --git a/go.sum b/go.sum index 93bb9eb..ddc43b4 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= diff --git a/main.go b/main.go index 411f3b8..7adf04d 100644 --- a/main.go +++ b/main.go @@ -11,8 +11,10 @@ import ( "os" "os/signal" "slices" + "strings" "time" + "github.com/gorilla/websocket" pctx "github.com/qydysky/part/ctx" pfile "github.com/qydysky/part/file" plog "github.com/qydysky/part/log" @@ -43,9 +45,10 @@ func main() { }) if slices.Contains(os.Args[1:], "-q") { - logger.L(`I:`, "不输出警告") + logger.L(`I:`, "简化输出") delete(logger.Config.Prefix_string, `E:`) delete(logger.Config.Prefix_string, `W:`) + delete(logger.Config.Prefix_string, `T:`) } // 根ctx @@ -198,86 +201,182 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log for i := 0; i < len(configS.Routes); i++ { route := &configS.Routes[i] + path := route.Path if !route.SwapSign() { continue } if len(route.Back) == 0 { - logger.L(`I:`, "移除路由", route.Path) - routeP.Store(route.Path, nil) + logger.L(`I:`, "移除路由", path) + routeP.Store(path, nil) continue } backArray := route.GenBack() if len(backArray) == 0 { - logger.L(`I:`, "移除路由", route.Path) - routeP.Store(route.Path, nil) + logger.L(`I:`, "移除路由", path) + routeP.Store(path, nil) continue } - logger.L(`I:`, "路由更新", route.Path) + logger.L(`I:`, "路由更新", path) - routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) { + routeP.Store(path, func(w http.ResponseWriter, r *http.Request) { ctx1, done1 := pctx.WaitCtx(ctx) defer done1() back := backArray[time.Now().UnixMilli()%int64(len(backArray))] - req, e := http.NewRequestWithContext(ctx1, r.Method, back.To+r.URL.String(), r.Body) - if e != nil { - pweb.WithStatusCode(w, http.StatusServiceUnavailable) - logger.L(`E:`, fmt.Sprintf("%s=>%s %v", route.Path, back.To, e)) - return - } + logger.L(`T:`, fmt.Sprintf("%s=>%s", path, back.To)) - for k, v := range r.Header { - req.Header.Set(k, v[0]) + if r.Header.Get("Upgrade") == "websocket" { + wsDealer(ctx1, w, r, path, back, logger) + } else { + httpDealer(ctx1, w, r, path, back, logger) } + }) + } + return nil +} - for _, v := range back.ReqHeader { - switch v.Action { - case `set`: - req.Header.Set(v.Key, v.Value) - case `add`: - req.Header.Add(v.Key, v.Value) - case `del`: - req.Header.Del(v.Key) - default: - logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", route.Path, back.To, v)) - } - } +func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger *plog.Log_interface) { + url := back.To + if back.PathAdd { + url += r.URL.String() + } - resp, e := http.DefaultClient.Do(req) - if e != nil { - pweb.WithStatusCode(w, http.StatusServiceUnavailable) - logger.L(`E:`, fmt.Sprintf("%s=>%s %v", route.Path, back.To, e)) - return - } + if !strings.HasPrefix(url, "http") { + pweb.WithStatusCode(w, http.StatusServiceUnavailable) + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, "非http")) + return + } - for k, v := range resp.Header { - w.Header().Set(k, v[0]) - } + req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body) + if e != nil { + pweb.WithStatusCode(w, http.StatusServiceUnavailable) + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + return + } - for _, v := range back.ResHeader { - switch v.Action { - case `set`: - w.Header().Set(v.Key, v.Value) - case `add`: - w.Header().Add(v.Key, v.Value) - case `del`: - w.Header().Del(v.Key) - default: - logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", route.Path, back.To, v)) - } - } + for k, v := range r.Header { + req.Header.Set(k, v[0]) + } - w.WriteHeader(resp.StatusCode) + for _, v := range back.ReqHeader { + switch v.Action { + case `set`: + req.Header.Set(v.Key, v.Value) + case `add`: + req.Header.Add(v.Key, v.Value) + case `del`: + req.Header.Del(v.Key) + default: + logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v)) + } + } + client := http.Client{} + resp, e := client.Do(req) + if e != nil { + pweb.WithStatusCode(w, http.StatusServiceUnavailable) + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + return + } - _, _ = io.Copy(w, resp.Body) - resp.Body.Close() - }) + for k, v := range resp.Header { + w.Header().Set(k, v[0]) + } + + for _, v := range back.ResHeader { + switch v.Action { + case `set`: + w.Header().Set(v.Key, v.Value) + case `add`: + w.Header().Add(v.Key, v.Value) + case `del`: + w.Header().Del(v.Key) + default: + logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v)) + } + } + + w.WriteHeader(resp.StatusCode) + + if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 { + return + } + + w = pweb.WithFlush(w) + if _, e = io.Copy(w, resp.Body); e != nil { + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + } + resp.Body.Close() +} + +func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger *plog.Log_interface) { + url := back.To + if back.PathAdd { + url += r.URL.String() + } + + if !strings.HasPrefix(url, "ws") { + pweb.WithStatusCode(w, http.StatusServiceUnavailable) + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, "非websocket")) + return + } + + reqHeader := make(http.Header) + for _, v := range back.ReqHeader { + switch v.Action { + case `set`: + reqHeader.Set(v.Key, v.Value) + case `add`: + reqHeader.Add(v.Key, v.Value) + case `del`: + reqHeader.Del(v.Key) + default: + logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v)) + } + } + if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil { + pweb.WithStatusCode(w, http.StatusServiceUnavailable) + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + } else { + for _, v := range back.ResHeader { + switch v.Action { + case `set`: + resp.Header.Set(v.Key, v.Value) + case `add`: + resp.Header.Add(v.Key, v.Value) + case `del`: + resp.Header.Del(v.Key) + default: + logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v)) + } + } + + if req, e := (&websocket.Upgrader{}).Upgrade(w, r, resp.Header); e != nil { + pweb.WithStatusCode(w, http.StatusServiceUnavailable) + logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + } else { + fin := make(chan struct{}) + reqc := req.NetConn() + resc := res.NetConn() + go func() { + _, _ = io.Copy(reqc, resc) + fin <- struct{}{} + }() + go func() { + _, _ = io.Copy(resc, reqc) + fin <- struct{}{} + }() + select { + case <-fin: + case <-ctx.Done(): + } + reqc.Close() + resc.Close() + } } - return nil } diff --git a/main.json b/main.json index 92e68c8..ff74276 100644 --- a/main.json +++ b/main.json @@ -6,8 +6,10 @@ "path": "/", "back": [ { + "name": "test", "to": "http://127.0.0.1:13000", "weight": 1, + "pathAdd": false, "resHeader":[ { "action": "set", -- 2.39.2