"os"
"os/signal"
"slices"
+ "strings"
"time"
+ "github.com/gorilla/websocket"
pctx "github.com/qydysky/part/ctx"
pfile "github.com/qydysky/part/file"
plog "github.com/qydysky/part/log"
})
if slices.Contains(os.Args[1:], "-q") {
- logger.L(`I:`, "不输出警告")
+ logger.L(`I:`, "简化输出")
delete(logger.Config.Prefix_string, `E:`)
delete(logger.Config.Prefix_string, `W:`)
+ delete(logger.Config.Prefix_string, `T:`)
}
// 根ctx
for i := 0; i < len(configS.Routes); i++ {
route := &configS.Routes[i]
+ path := route.Path
if !route.SwapSign() {
continue
}
if len(route.Back) == 0 {
- logger.L(`I:`, "移除路由", route.Path)
- routeP.Store(route.Path, nil)
+ logger.L(`I:`, "移除路由", path)
+ routeP.Store(path, nil)
continue
}
backArray := route.GenBack()
if len(backArray) == 0 {
- logger.L(`I:`, "移除路由", route.Path)
- routeP.Store(route.Path, nil)
+ logger.L(`I:`, "移除路由", path)
+ routeP.Store(path, nil)
continue
}
- logger.L(`I:`, "路由更新", route.Path)
+ logger.L(`I:`, "路由更新", path)
- routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) {
+ routeP.Store(path, func(w http.ResponseWriter, r *http.Request) {
ctx1, done1 := pctx.WaitCtx(ctx)
defer done1()
back := backArray[time.Now().UnixMilli()%int64(len(backArray))]
- req, e := http.NewRequestWithContext(ctx1, r.Method, back.To+r.URL.String(), r.Body)
- if e != nil {
- pweb.WithStatusCode(w, http.StatusServiceUnavailable)
- logger.L(`E:`, fmt.Sprintf("%s=>%s %v", route.Path, back.To, e))
- return
- }
+ logger.L(`T:`, fmt.Sprintf("%s=>%s", path, back.To))
- for k, v := range r.Header {
- req.Header.Set(k, v[0])
+ if r.Header.Get("Upgrade") == "websocket" {
+ wsDealer(ctx1, w, r, path, back, logger)
+ } else {
+ httpDealer(ctx1, w, r, path, back, logger)
}
+ })
+ }
+ return nil
+}
- for _, v := range back.ReqHeader {
- switch v.Action {
- case `set`:
- req.Header.Set(v.Key, v.Value)
- case `add`:
- req.Header.Add(v.Key, v.Value)
- case `del`:
- req.Header.Del(v.Key)
- default:
- logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", route.Path, back.To, v))
- }
- }
+func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger *plog.Log_interface) {
+ url := back.To
+ if back.PathAdd {
+ url += r.URL.String()
+ }
- resp, e := http.DefaultClient.Do(req)
- if e != nil {
- pweb.WithStatusCode(w, http.StatusServiceUnavailable)
- logger.L(`E:`, fmt.Sprintf("%s=>%s %v", route.Path, back.To, e))
- return
- }
+ if !strings.HasPrefix(url, "http") {
+ pweb.WithStatusCode(w, http.StatusServiceUnavailable)
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, "非http"))
+ return
+ }
- for k, v := range resp.Header {
- w.Header().Set(k, v[0])
- }
+ req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
+ if e != nil {
+ pweb.WithStatusCode(w, http.StatusServiceUnavailable)
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ return
+ }
- for _, v := range back.ResHeader {
- switch v.Action {
- case `set`:
- w.Header().Set(v.Key, v.Value)
- case `add`:
- w.Header().Add(v.Key, v.Value)
- case `del`:
- w.Header().Del(v.Key)
- default:
- logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", route.Path, back.To, v))
- }
- }
+ for k, v := range r.Header {
+ req.Header.Set(k, v[0])
+ }
- w.WriteHeader(resp.StatusCode)
+ for _, v := range back.ReqHeader {
+ switch v.Action {
+ case `set`:
+ req.Header.Set(v.Key, v.Value)
+ case `add`:
+ req.Header.Add(v.Key, v.Value)
+ case `del`:
+ req.Header.Del(v.Key)
+ default:
+ logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v))
+ }
+ }
+ client := http.Client{}
+ resp, e := client.Do(req)
+ if e != nil {
+ pweb.WithStatusCode(w, http.StatusServiceUnavailable)
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ return
+ }
- _, _ = io.Copy(w, resp.Body)
- resp.Body.Close()
- })
+ for k, v := range resp.Header {
+ w.Header().Set(k, v[0])
+ }
+
+ for _, v := range back.ResHeader {
+ switch v.Action {
+ case `set`:
+ w.Header().Set(v.Key, v.Value)
+ case `add`:
+ w.Header().Add(v.Key, v.Value)
+ case `del`:
+ w.Header().Del(v.Key)
+ default:
+ logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
+ }
+ }
+
+ w.WriteHeader(resp.StatusCode)
+
+ if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 {
+ return
+ }
+
+ w = pweb.WithFlush(w)
+ if _, e = io.Copy(w, resp.Body); e != nil {
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ }
+ resp.Body.Close()
+}
+
+func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger *plog.Log_interface) {
+ url := back.To
+ if back.PathAdd {
+ url += r.URL.String()
+ }
+
+ if !strings.HasPrefix(url, "ws") {
+ pweb.WithStatusCode(w, http.StatusServiceUnavailable)
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, "非websocket"))
+ return
+ }
+
+ reqHeader := make(http.Header)
+ for _, v := range back.ReqHeader {
+ switch v.Action {
+ case `set`:
+ reqHeader.Set(v.Key, v.Value)
+ case `add`:
+ reqHeader.Add(v.Key, v.Value)
+ case `del`:
+ reqHeader.Del(v.Key)
+ default:
+ logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v))
+ }
+ }
+ if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil {
+ pweb.WithStatusCode(w, http.StatusServiceUnavailable)
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ } else {
+ for _, v := range back.ResHeader {
+ switch v.Action {
+ case `set`:
+ resp.Header.Set(v.Key, v.Value)
+ case `add`:
+ resp.Header.Add(v.Key, v.Value)
+ case `del`:
+ resp.Header.Del(v.Key)
+ default:
+ logger.L(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
+ }
+ }
+
+ if req, e := (&websocket.Upgrader{}).Upgrade(w, r, resp.Header); e != nil {
+ pweb.WithStatusCode(w, http.StatusServiceUnavailable)
+ logger.L(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ } else {
+ fin := make(chan struct{})
+ reqc := req.NetConn()
+ resc := res.NetConn()
+ go func() {
+ _, _ = io.Copy(reqc, resc)
+ fin <- struct{}{}
+ }()
+ go func() {
+ _, _ = io.Copy(resc, reqc)
+ fin <- struct{}{}
+ }()
+ select {
+ case <-fin:
+ case <-ctx.Done():
+ }
+ reqc.Close()
+ resc.Close()
+ }
}
- return nil
}