"errors"
"fmt"
"io"
+ "math/rand"
"net/http"
+ "net/http/cookiejar"
+ "strings"
"time"
_ "unsafe"
Read(data []byte) (int, error)
}
+//go:linkname validCookieDomain net/http.validCookieDomain
+func validCookieDomain(v string) bool
+
// 加载
func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error {
if e := loadConfig(buf, configF, configS); e != nil {
case 1:
backI = backArray[0]
default:
- backI = backArray[time.Now().UnixMilli()%int64(len(backArray))]
+ backI = backArray[rand.Int63()%int64(len(backArray))]
}
if !backI.IsLive() {
backI = nil
req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
if e != nil {
- return ErrReqCreFail
+ return errors.Join(ErrReqCreFail, e)
}
- for k, v := range r.Header {
- req.Header.Set(k, v[0])
+ if e := copyHeader(r.Header, req.Header, back.ReqHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ return e
}
- for _, v := range back.ReqHeader {
- switch v.Action {
- case `check`:
- case `set`:
- req.Header.Set(v.Key, v.Value)
- case `add`:
- req.Header.Add(v.Key, v.Value)
- case `del`:
- req.Header.Del(v.Key)
- default:
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v))
- }
- }
client := http.Client{}
resp, e := client.Do(req)
if e != nil {
- return ErrReqDoFail
+ return errors.Join(ErrReqDoFail, e)
}
- header := w.Header()
- for k, v := range resp.Header {
- if has(&header, k) {
- header.Add(k, v[0])
- } else {
- header.Set(k, v[0])
- }
- }
-
- for _, v := range back.ResHeader {
- switch v.Action {
- case `check`:
- if resp.Header.Get(v.Key) != v.Value {
- return ErrHeaderCheckFail
- }
- case `set`:
- header.Set(v.Key, v.Value)
- case `add`:
- header.Add(v.Key, v.Value)
- case `del`:
- header.Del(v.Key)
- default:
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
- }
+ if e := copyHeader(resp.Header, w.Header(), back.ResHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ return e
}
w.WriteHeader(resp.StatusCode)
defer resp.Body.Close()
if _, e = io.Copy(w, resp.Body); e != nil {
logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
- return ErrCopy
+ return errors.Join(ErrCopy, e)
}
return nil
}
url = "ws" + url
- reqHeader := make(http.Header)
- for _, v := range back.ReqHeader {
- switch v.Action {
- case `check`:
- if r.Header.Get(v.Key) != v.Value {
- return ErrHeaderCheckFail
- }
- case `set`:
- reqHeader.Set(v.Key, v.Value)
- case `add`:
- reqHeader.Add(v.Key, v.Value)
- case `del`:
- reqHeader.Del(v.Key)
- default:
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ReqHeader %v", routePath, back.Name, v))
- }
+ tmpHeader := make(http.Header)
+
+ dialer := websocket.DefaultDialer
+
+ if strings.Contains(r.Header.Get("Sec-Websocket-Extensions"), "permessage-deflate") {
+ dialer.EnableCompression = true
}
- if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil {
- return ErrReqDoFail
+
+ if e := copyHeader(r.Header, tmpHeader, back.ReqHeader, delWsHeaders); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ return e
+ }
+
+ jar, _ := cookiejar.New(nil)
+ jar.SetCookies(r.URL, r.Cookies())
+
+ dialer.Jar = jar
+
+ if res, resp, e := dialer.Dial(url, tmpHeader); e != nil {
+ return errors.Join(ErrReqDoFail, e)
} else {
- resp.Header.Del("connection")
- resp.Header.Del("upgrade")
- resp.Header.Del("sec-websocket-accept")
-
- for _, v := range back.ResHeader {
- switch v.Action {
- case `check`:
- if resp.Header.Get(v.Key) != v.Value {
- return ErrHeaderCheckFail
- }
- case `set`:
- resp.Header.Set(v.Key, v.Value)
- case `add`:
- resp.Header.Add(v.Key, v.Value)
- case `del`:
- resp.Header.Del(v.Key)
- default:
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
- }
+ defer res.Close()
+
+ clear(tmpHeader)
+ if e := copyHeader(resp.Header, tmpHeader, back.ResHeader, delWsHeaders); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ return e
}
- if req, e := (&websocket.Upgrader{}).Upgrade(w, r, resp.Header); e != nil {
- return ErrResDoFail
+ if req, e := (&websocket.Upgrader{}).Upgrade(w, r, tmpHeader); e != nil {
+ return errors.Join(ErrResDoFail, e)
} else {
- ctx, cancle := pctx.WithWait(ctx, 2, time.Second*45)
- defer func() {
- _ = cancle()
- }()
- fin := make(chan error)
+ defer req.Close()
+
+ fin := make(chan error, 2)
reqc := req.NetConn()
resc := res.NetConn()
go func() {
ctx1, done1 := pctx.WaitCtx(ctx)
defer done1()
- _, e := io.Copy(reqc, resc)
select {
- case fin <- e:
+ case fin <- copyWsMsg(reqc, resc):
case <-ctx1.Done():
fin <- nil
}
- reqc.Close()
}()
go func() {
ctx1, done1 := pctx.WaitCtx(ctx)
defer done1()
- _, e := io.Copy(resc, reqc)
select {
- case fin <- e:
+ case fin <- copyWsMsg(resc, reqc):
case <-ctx1.Done():
fin <- nil
}
- resc.Close()
}()
if e := <-fin; e != nil {
logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
- return ErrCopy
+ return errors.Join(ErrCopy, e)
}
return nil
}
}
}
-//go:linkname has net/http.(*Header).has
-func has(h *http.Header, key string) bool
+var delWsHeaders = map[string]*any{
+ `Connection`: nil,
+ `Upgrade`: nil,
+ `Sec-Websocket-Accept`: nil,
+ `Sec-Websocket-Key`: nil,
+ `Sec-Websocket-Version`: nil,
+ `Sec-Websocket-Protocol`: nil,
+ `Sec-Websocket-Extensions`: nil,
+}
-//go:linkname validCookieDomain net/http.validCookieDomain
-func validCookieDomain(v string) bool
+func copyHeader(s, t http.Header, app []Header, delHeader ...map[string]*any) error {
+ sm := (map[string][]string)(s)
+ tm := (map[string][]string)(t)
+ for k, v := range sm {
+ if _, ok := delWsHeaders[k]; ok {
+ continue
+ }
+ tm[k] = v
+ }
+ for _, v := range app {
+ switch v.Action {
+ case `check`:
+ if val := tm[v.Key]; val[0] != v.Value {
+ return ErrHeaderCheckFail
+ }
+ case `set`:
+ t.Set(v.Key, v.Value)
+ case `add`:
+ t.Add(v.Key, v.Value)
+ case `del`:
+ t.Del(v.Key)
+ default:
+ }
+ }
+ return nil
+}
+
+func copyWsMsg(dst io.Writer, src io.Reader) (e error) {
+ _, e = io.Copy(dst, src)
+ return
+}