From a4328da8a4a7e23dcd70b64ba4546f320ed1a13f Mon Sep 17 00:00:00 2001 From: qydysky Date: Wed, 13 Dec 2023 03:49:26 +0800 Subject: [PATCH] 1 --- main.go | 187 ++++++++++++++++++++++++--------------------------- main/main.go | 5 ++ 2 files changed, 94 insertions(+), 98 deletions(-) diff --git a/main.go b/main.go index 5d72b6c..95965e6 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,10 @@ import ( "errors" "fmt" "io" + "math/rand" "net/http" + "net/http/cookiejar" + "strings" "time" _ "unsafe" @@ -27,6 +30,9 @@ type File interface { Read(data []byte) (int, error) } +//go:linkname validCookieDomain net/http.validCookieDomain +func validCookieDomain(v string) bool + // 加载 func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error { if e := loadConfig(buf, configF, configS); e != nil { @@ -220,7 +226,7 @@ func applyConfig(ctx context.Context, configS *Config, routeP *pweb.WebPath, log case 1: backI = backArray[0] default: - backI = backArray[time.Now().UnixMilli()%int64(len(backArray))] + backI = backArray[rand.Int63()%int64(len(backArray))] } if !backI.IsLive() { backI = nil @@ -308,56 +314,23 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body) if e != nil { - return ErrReqCreFail + return errors.Join(ErrReqCreFail, e) } - for k, v := range r.Header { - req.Header.Set(k, v[0]) + if e := copyHeader(r.Header, req.Header, back.ReqHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + return e } - for _, v := range back.ReqHeader { - switch v.Action { - case `check`: - case `set`: - req.Header.Set(v.Key, v.Value) - case `add`: - req.Header.Add(v.Key, v.Value) - case `del`: - req.Header.Del(v.Key) - default: - logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v)) - } - } client := http.Client{} resp, e := client.Do(req) if e != nil { - return ErrReqDoFail + return errors.Join(ErrReqDoFail, e) } - header := w.Header() - for k, v := range resp.Header { - if has(&header, k) { - header.Add(k, v[0]) - } else { - header.Set(k, v[0]) - } - } - - for _, v := range back.ResHeader { - switch v.Action { - case `check`: - if resp.Header.Get(v.Key) != v.Value { - return ErrHeaderCheckFail - } - case `set`: - header.Set(v.Key, v.Value) - case `add`: - header.Add(v.Key, v.Value) - case `del`: - header.Del(v.Key) - default: - logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v)) - } + if e := copyHeader(resp.Header, w.Header(), back.ResHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + return e } w.WriteHeader(resp.StatusCode) @@ -369,7 +342,7 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou defer resp.Body.Close() if _, e = io.Copy(w, resp.Body); e != nil { logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) - return ErrCopy + return errors.Join(ErrCopy, e) } return nil } @@ -382,90 +355,108 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route url = "ws" + url - reqHeader := make(http.Header) - for _, v := range back.ReqHeader { - switch v.Action { - case `check`: - if r.Header.Get(v.Key) != v.Value { - return ErrHeaderCheckFail - } - case `set`: - reqHeader.Set(v.Key, v.Value) - case `add`: - reqHeader.Add(v.Key, v.Value) - case `del`: - reqHeader.Del(v.Key) - default: - logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v)) - } + tmpHeader := make(http.Header) + + dialer := websocket.DefaultDialer + + if strings.Contains(r.Header.Get("Sec-Websocket-Extensions"), "permessage-deflate") { + dialer.EnableCompression = true } - if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil { - return ErrReqDoFail + + if e := copyHeader(r.Header, tmpHeader, back.ReqHeader, delWsHeaders); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + return e + } + + jar, _ := cookiejar.New(nil) + jar.SetCookies(r.URL, r.Cookies()) + + dialer.Jar = jar + + if res, resp, e := dialer.Dial(url, tmpHeader); e != nil { + return errors.Join(ErrReqDoFail, e) } else { - resp.Header.Del("connection") - resp.Header.Del("upgrade") - resp.Header.Del("sec-websocket-accept") - - for _, v := range back.ResHeader { - switch v.Action { - case `check`: - if resp.Header.Get(v.Key) != v.Value { - return ErrHeaderCheckFail - } - case `set`: - resp.Header.Set(v.Key, v.Value) - case `add`: - resp.Header.Add(v.Key, v.Value) - case `del`: - resp.Header.Del(v.Key) - default: - logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v)) - } + defer res.Close() + + clear(tmpHeader) + if e := copyHeader(resp.Header, tmpHeader, back.ResHeader, delWsHeaders); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) + return e } - if req, e := (&websocket.Upgrader{}).Upgrade(w, r, resp.Header); e != nil { - return ErrResDoFail + if req, e := (&websocket.Upgrader{}).Upgrade(w, r, tmpHeader); e != nil { + return errors.Join(ErrResDoFail, e) } else { - ctx, cancle := pctx.WithWait(ctx, 2, time.Second*45) - defer func() { - _ = cancle() - }() - fin := make(chan error) + defer req.Close() + + fin := make(chan error, 2) reqc := req.NetConn() resc := res.NetConn() go func() { ctx1, done1 := pctx.WaitCtx(ctx) defer done1() - _, e := io.Copy(reqc, resc) select { - case fin <- e: + case fin <- copyWsMsg(reqc, resc): case <-ctx1.Done(): fin <- nil } - reqc.Close() }() go func() { ctx1, done1 := pctx.WaitCtx(ctx) defer done1() - _, e := io.Copy(resc, reqc) select { - case fin <- e: + case fin <- copyWsMsg(resc, reqc): case <-ctx1.Done(): fin <- nil } - resc.Close() }() if e := <-fin; e != nil { logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e)) - return ErrCopy + return errors.Join(ErrCopy, e) } return nil } } } -//go:linkname has net/http.(*Header).has -func has(h *http.Header, key string) bool +var delWsHeaders = map[string]*any{ + `Connection`: nil, + `Upgrade`: nil, + `Sec-Websocket-Accept`: nil, + `Sec-Websocket-Key`: nil, + `Sec-Websocket-Version`: nil, + `Sec-Websocket-Protocol`: nil, + `Sec-Websocket-Extensions`: nil, +} -//go:linkname validCookieDomain net/http.validCookieDomain -func validCookieDomain(v string) bool +func copyHeader(s, t http.Header, app []Header, delHeader ...map[string]*any) error { + sm := (map[string][]string)(s) + tm := (map[string][]string)(t) + for k, v := range sm { + if _, ok := delWsHeaders[k]; ok { + continue + } + tm[k] = v + } + for _, v := range app { + switch v.Action { + case `check`: + if val := tm[v.Key]; val[0] != v.Value { + return ErrHeaderCheckFail + } + case `set`: + t.Set(v.Key, v.Value) + case `add`: + t.Add(v.Key, v.Value) + case `del`: + t.Del(v.Key) + default: + } + } + return nil +} + +func copyWsMsg(dst io.Writer, src io.Reader) (e error) { + _, e = io.Copy(dst, src) + return +} diff --git a/main/main.go b/main/main.go index 0f45c05..7ad9803 100644 --- a/main/main.go +++ b/main/main.go @@ -75,6 +75,11 @@ func main() { var interrupt = make(chan os.Signal, 2) signal.Notify(interrupt, os.Interrupt) <-interrupt + logger.L(`I:`, "退出中,再次ctrl+c强制退出") + go func() { + <-interrupt + os.Exit(1) + }() if errors.Is(cancle(), pctx.ErrWaitTo) { logger.L(`E:`, "退出超时") } -- 2.39.2