]> 127.0.0.1 Git - front/.git/commitdiff
1 v0.1.20231212194948
authorqydysky <qydysky@foxmail.com>
Tue, 12 Dec 2023 19:49:26 +0000 (03:49 +0800)
committerqydysky <qydysky@foxmail.com>
Tue, 12 Dec 2023 19:49:26 +0000 (03:49 +0800)
main.go
main/main.go

diff --git a/main.go b/main.go
index 5d72b6c6d1a09031511b361069d52d76ede0b0a1..95965e69e400adcd67965f6cbbca6372090dc1ef 100644 (file)
--- a/main.go
+++ b/main.go
@@ -7,7 +7,10 @@ import (
        "errors"
        "fmt"
        "io"
+       "math/rand"
        "net/http"
+       "net/http/cookiejar"
+       "strings"
        "time"
        _ "unsafe"
 
@@ -27,6 +30,9 @@ type File interface {
        Read(data []byte) (int, error)
 }
 
+//go:linkname validCookieDomain net/http.validCookieDomain
+func validCookieDomain(v string) bool
+
 // 加载
 func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error {
        if e := loadConfig(buf, configF, configS); e != nil {
@@ -220,7 +226,7 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log
                                case 1:
                                        backI = backArray[0]
                                default:
-                                       backI = backArray[time.Now().UnixMilli()%int64(len(backArray))]
+                                       backI = backArray[rand.Int63()%int64(len(backArray))]
                                }
                                if !backI.IsLive() {
                                        backI = nil
@@ -308,56 +314,23 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
 
        req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
        if e != nil {
-               return ErrReqCreFail
+               return errors.Join(ErrReqCreFail, e)
        }
 
-       for k, v := range r.Header {
-               req.Header.Set(k, v[0])
+       if e := copyHeader(r.Header, req.Header, back.ReqHeader); e != nil {
+               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+               return e
        }
 
-       for _, v := range back.ReqHeader {
-               switch v.Action {
-               case `check`:
-               case `set`:
-                       req.Header.Set(v.Key, v.Value)
-               case `add`:
-                       req.Header.Add(v.Key, v.Value)
-               case `del`:
-                       req.Header.Del(v.Key)
-               default:
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v))
-               }
-       }
        client := http.Client{}
        resp, e := client.Do(req)
        if e != nil {
-               return ErrReqDoFail
+               return errors.Join(ErrReqDoFail, e)
        }
 
-       header := w.Header()
-       for k, v := range resp.Header {
-               if has(&header, k) {
-                       header.Add(k, v[0])
-               } else {
-                       header.Set(k, v[0])
-               }
-       }
-
-       for _, v := range back.ResHeader {
-               switch v.Action {
-               case `check`:
-                       if resp.Header.Get(v.Key) != v.Value {
-                               return ErrHeaderCheckFail
-                       }
-               case `set`:
-                       header.Set(v.Key, v.Value)
-               case `add`:
-                       header.Add(v.Key, v.Value)
-               case `del`:
-                       header.Del(v.Key)
-               default:
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
-               }
+       if e := copyHeader(resp.Header, w.Header(), back.ResHeader); e != nil {
+               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+               return e
        }
 
        w.WriteHeader(resp.StatusCode)
@@ -369,7 +342,7 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
        defer resp.Body.Close()
        if _, e = io.Copy(w, resp.Body); e != nil {
                logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
-               return ErrCopy
+               return errors.Join(ErrCopy, e)
        }
        return nil
 }
@@ -382,90 +355,108 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
 
        url = "ws" + url
 
-       reqHeader := make(http.Header)
-       for _, v := range back.ReqHeader {
-               switch v.Action {
-               case `check`:
-                       if r.Header.Get(v.Key) != v.Value {
-                               return ErrHeaderCheckFail
-                       }
-               case `set`:
-                       reqHeader.Set(v.Key, v.Value)
-               case `add`:
-                       reqHeader.Add(v.Key, v.Value)
-               case `del`:
-                       reqHeader.Del(v.Key)
-               default:
-                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v))
-               }
+       tmpHeader := make(http.Header)
+
+       dialer := websocket.DefaultDialer
+
+       if strings.Contains(r.Header.Get("Sec-Websocket-Extensions"), "permessage-deflate") {
+               dialer.EnableCompression = true
        }
-       if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil {
-               return ErrReqDoFail
+
+       if e := copyHeader(r.Header, tmpHeader, back.ReqHeader, delWsHeaders); e != nil {
+               logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+               return e
+       }
+
+       jar, _ := cookiejar.New(nil)
+       jar.SetCookies(r.URL, r.Cookies())
+
+       dialer.Jar = jar
+
+       if res, resp, e := dialer.Dial(url, tmpHeader); e != nil {
+               return errors.Join(ErrReqDoFail, e)
        } else {
-               resp.Header.Del("connection")
-               resp.Header.Del("upgrade")
-               resp.Header.Del("sec-websocket-accept")
-
-               for _, v := range back.ResHeader {
-                       switch v.Action {
-                       case `check`:
-                               if resp.Header.Get(v.Key) != v.Value {
-                                       return ErrHeaderCheckFail
-                               }
-                       case `set`:
-                               resp.Header.Set(v.Key, v.Value)
-                       case `add`:
-                               resp.Header.Add(v.Key, v.Value)
-                       case `del`:
-                               resp.Header.Del(v.Key)
-                       default:
-                               logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
-                       }
+               defer res.Close()
+
+               clear(tmpHeader)
+               if e := copyHeader(resp.Header, tmpHeader, back.ResHeader, delWsHeaders); e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+                       return e
                }
 
-               if req, e := (&websocket.Upgrader{}).Upgrade(w, r, resp.Header); e != nil {
-                       return ErrResDoFail
+               if req, e := (&websocket.Upgrader{}).Upgrade(w, r, tmpHeader); e != nil {
+                       return errors.Join(ErrResDoFail, e)
                } else {
-                       ctx, cancle := pctx.WithWait(ctx, 2, time.Second*45)
-                       defer func() {
-                               _ = cancle()
-                       }()
-                       fin := make(chan error)
+                       defer req.Close()
+
+                       fin := make(chan error, 2)
                        reqc := req.NetConn()
                        resc := res.NetConn()
                        go func() {
                                ctx1, done1 := pctx.WaitCtx(ctx)
                                defer done1()
-                               _, e := io.Copy(reqc, resc)
                                select {
-                               case fin <- e:
+                               case fin <- copyWsMsg(reqc, resc):
                                case <-ctx1.Done():
                                        fin <- nil
                                }
-                               reqc.Close()
                        }()
                        go func() {
                                ctx1, done1 := pctx.WaitCtx(ctx)
                                defer done1()
-                               _, e := io.Copy(resc, reqc)
                                select {
-                               case fin <- e:
+                               case fin <- copyWsMsg(resc, reqc):
                                case <-ctx1.Done():
                                        fin <- nil
                                }
-                               resc.Close()
                        }()
                        if e := <-fin; e != nil {
                                logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
-                               return ErrCopy
+                               return errors.Join(ErrCopy, e)
                        }
                        return nil
                }
        }
 }
 
-//go:linkname has net/http.(*Header).has
-func has(h *http.Header, key string) bool
+var delWsHeaders = map[string]*any{
+       `Connection`:               nil,
+       `Upgrade`:                  nil,
+       `Sec-Websocket-Accept`:     nil,
+       `Sec-Websocket-Key`:        nil,
+       `Sec-Websocket-Version`:    nil,
+       `Sec-Websocket-Protocol`:   nil,
+       `Sec-Websocket-Extensions`: nil,
+}
 
-//go:linkname validCookieDomain net/http.validCookieDomain
-func validCookieDomain(v string) bool
+func copyHeader(s, t http.Header, app []Header, delHeader ...map[string]*any) error {
+       sm := (map[string][]string)(s)
+       tm := (map[string][]string)(t)
+       for k, v := range sm {
+               if _, ok := delWsHeaders[k]; ok {
+                       continue
+               }
+               tm[k] = v
+       }
+       for _, v := range app {
+               switch v.Action {
+               case `check`:
+                       if val := tm[v.Key]; val[0] != v.Value {
+                               return ErrHeaderCheckFail
+                       }
+               case `set`:
+                       t.Set(v.Key, v.Value)
+               case `add`:
+                       t.Add(v.Key, v.Value)
+               case `del`:
+                       t.Del(v.Key)
+               default:
+               }
+       }
+       return nil
+}
+
+func copyWsMsg(dst io.Writer, src io.Reader) (e error) {
+       _, e = io.Copy(dst, src)
+       return
+}
index 0f45c054162763462ffb693b0b333c1b2906b836..7ad9803f7b1be2088f589a05db0792c3df081526 100644 (file)
@@ -75,6 +75,11 @@ func main() {
        var interrupt = make(chan os.Signal, 2)
        signal.Notify(interrupt, os.Interrupt)
        <-interrupt
+       logger.L(`I:`, "退出中,再次ctrl+c强制退出")
+       go func() {
+               <-interrupt
+               os.Exit(1)
+       }()
        if errors.Is(cancle(), pctx.ErrWaitTo) {
                logger.L(`E:`, "退出超时")
        }