"net/http"
"net/http/httptrace"
"net/url"
+ "strings"
"time"
_ "unsafe"
for i := 0; i < len(configS.Routes); i++ {
route := &configS.Routes[i]
path := route.Path
- splicing := route.Splicing
- ErrRedirect := route.ErrRedirect
if !route.SwapSign() {
continue
ctx1, done1 := pctx.WaitCtx(ctx)
defer done1()
- var backI *Back
- validHost := validCookieDomain(r.Host)
- if validHost {
+ var backIs []*Back
+ if validCookieDomain(r.Host) {
if t, e := r.Cookie("_psign_" + cookie); e == nil {
if tmp, ok := backMap[t.Value]; ok {
- backI = tmp
+ backIs = append(backIs, tmp)
}
}
}
- if backI == nil {
- backArray := GetBackByRequest(backArray, r)
- switch len(backArray) {
- case 0:
- w.WriteHeader(http.StatusServiceUnavailable)
- logger.Error(`W:`, fmt.Sprintf("%s=> 无匹配", path))
- return
- case 1:
- backI = backArray[0]
- default:
- backI = backArray[nanotime1()%int64(len(backArray))]
- }
- if !backI.IsLive() {
- backI = nil
- for i := 0; i < len(backArray); i++ {
- if backArray[i].IsLive() {
- backI = backArray[i]
- break
- }
- }
- if backI == nil {
- w.WriteHeader(http.StatusServiceUnavailable)
- logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path))
- return
- }
- }
+ if len(backIs) == 0 {
+ backIs = append(backIs, FiliterBackByRequest(backArray, r)...)
}
- if validHost {
- w.Header().Add("Set-Cookie", (&http.Cookie{
- Name: "_psign_" + cookie,
- Value: backI.Sign,
- MaxAge: splicing,
- Path: path,
- Domain: r.Host,
- }).String())
+ if len(backIs) == 0 {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ logger.Error(`W:`, fmt.Sprintf("%s=> 无可用后端", path))
+ return
}
- w.Header().Add("_pto_"+cookie, backI.Name)
-
- logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backI.Name))
-
var e error
if r.Header.Get("Upgrade") == "websocket" {
- e = wsDealer(ctx1, w, r, path, backI, logger, configS.BlocksI)
+ e = wsDealer(ctx1, w, r, path, backIs, logger, configS.BlocksI)
} else {
- e = httpDealer(ctx1, w, r, path, backI, logger, configS.BlocksI)
+ e = httpDealer(ctx1, w, r, path, backIs, logger, configS.BlocksI)
}
- if e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backI.Name, e))
- switch e {
- case ErrCopy:
- backI.Disable()
- return
- case ErrHeaderCheckFail:
- w.WriteHeader(http.StatusForbidden)
- return
- default:
- backI.Disable()
- if ErrRedirect {
- w.Header().Set("Location", r.URL.String())
- w.WriteHeader(http.StatusTemporaryRedirect)
- }
- }
+ if errors.Is(e, ErrHeaderCheckFail) {
+ w.WriteHeader(http.StatusForbidden)
+ return
}
})
}
ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail")
)
-func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
- url := back.To
- if back.PathAdd {
- url += r.URL.String()
- }
+func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
+ var (
+ resp *http.Response
+ chosenBack *Back
+ )
- url = "http" + url
+ for 0 < len(backs) && resp == nil {
+ chosenBack = backs[0]
+ backs = backs[1:]
- for _, v := range back.ReqHeader {
- if v.Action == `check` {
- if r.Header.Get(v.Key) != v.Value {
- return ErrHeaderCheckFail
+ url := chosenBack.To
+ if chosenBack.PathAdd {
+ url += r.URL.String()
+ }
+
+ url = "http" + url
+
+ for _, v := range chosenBack.ReqHeader {
+ if v.Action == `check` {
+ if r.Header.Get(v.Key) != v.Value {
+ return ErrHeaderCheckFail
+ }
}
}
- }
- req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
- if e != nil {
- return errors.Join(ErrReqCreFail, e)
+ req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
+ if e != nil {
+ return errors.Join(ErrReqCreFail, e)
+ }
+
+ if e := copyHeader(r.Header, req.Header, chosenBack.ReqHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
+ return e
+ }
+
+ req.Header.Del("Referer")
+
+ client := http.Client{}
+ resp, e = client.Do(req)
+ if e != nil {
+ chosenBack.Disable()
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
+ }
}
- if e := copyHeader(r.Header, req.Header, back.ReqHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
- return e
+ if 0 == len(backs) && resp == nil {
+ logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name))
+ return errors.New("全部后端故障")
+ } else if resp == nil {
+ return errors.New("后端故障")
}
- client := http.Client{}
- resp, e := client.Do(req)
- if e != nil {
- return errors.Join(ErrReqDoFail, e)
+ logger.Error(`T:`, fmt.Sprintf("%s=>%s", routePath, chosenBack.Name))
+
+ if validCookieDomain(r.Host) {
+ w.Header().Add("Set-Cookie", (&http.Cookie{
+ Name: "_psign_" + cookie,
+ Value: chosenBack.Sign,
+ MaxAge: chosenBack.Splicing,
+ Domain: r.Host,
+ }).String())
}
- if e := copyHeader(resp.Header, w.Header(), back.ResHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ w.Header().Add("_pto_"+cookie, chosenBack.Name)
+
+ if e := copyHeader(resp.Header, w.Header(), chosenBack.ResHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
return e
}
defer resp.Body.Close()
if tmpbuf, put, e := blocksi.Get(); e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
+ chosenBack.Disable()
return errors.Join(ErrCopy, e)
} else {
defer put()
if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
+ chosenBack.Disable()
return errors.Join(ErrCopy, e)
}
}
return nil
}
-func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, back *Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
- url := back.To
- if back.PathAdd {
- url += r.URL.String()
- }
+func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
+ var (
+ resp *http.Response
+ conn net.Conn
+ chosenBack *Back
+ )
- url = "ws" + url
+ for 0 < len(backs) && resp == nil {
+ chosenBack = backs[0]
+ backs = backs[1:]
- reqHeader := make(http.Header)
+ url := chosenBack.To
+ if chosenBack.PathAdd {
+ url += r.URL.String()
+ }
- if e := copyHeader(r.Header, reqHeader, back.ReqHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
- return e
- }
+ url = "ws" + url
- if res, resp, e := DialContext(ctx, url, reqHeader); e != nil {
- return errors.Join(ErrReqDoFail, e)
- } else {
- defer res.Close()
+ reqHeader := make(http.Header)
- resHeader := make(http.Header)
- if e := copyHeader(resp.Header, resHeader, back.ResHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
+ if e := copyHeader(r.Header, reqHeader, chosenBack.ReqHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
return e
}
- if req, e := Upgrade(w, r, resHeader); e != nil {
- return errors.Join(ErrResDoFail, e)
- } else {
- defer req.Close()
+ reqHeader.Del("Referer")
- select {
- case e := <-copyWsMsg(req, res, blocksi):
- if e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, back.Name, e))
- return errors.Join(ErrCopy, e)
- }
- case e := <-copyWsMsg(res, req, blocksi):
- if e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, back.Name, e))
- return errors.Join(ErrCopy, e)
- }
- case <-ctx.Done():
- }
+ var e error
+ conn, resp, e = DialContext(ctx, url, reqHeader)
+ if e != nil {
+ chosenBack.Disable()
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
+ }
+ }
+
+ if 0 == len(backs) && (resp == nil || conn == nil) {
+ logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name))
+ return errors.New("全部后端故障")
+ } else if resp == nil || conn == nil {
+ return errors.New("后端故障")
+ }
+
+ logger.Error(`T:`, fmt.Sprintf("%s=>%s", routePath, chosenBack.Name))
+
+ if validCookieDomain(r.Host) {
+ w.Header().Add("Set-Cookie", (&http.Cookie{
+ Name: "_psign_" + cookie,
+ Value: chosenBack.Sign,
+ MaxAge: chosenBack.Splicing,
+ Domain: r.Host,
+ }).String())
+ }
+
+ w.Header().Add("_pto_"+cookie, chosenBack.Name)
+
+ defer conn.Close()
+
+ resHeader := make(http.Header)
+ if e := copyHeader(resp.Header, resHeader, chosenBack.ResHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
+ return e
+ }
+
+ if req, e := Upgrade(w, r, resHeader); e != nil {
+ return errors.Join(ErrResDoFail, e)
+ } else {
+ defer req.Close()
- return nil
+ select {
+ case e := <-copyWsMsg(req, conn, blocksi):
+ if e != nil {
+ chosenBack.Disable()
+ logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, chosenBack.Name, e))
+ return errors.Join(ErrCopy, e)
+ }
+ case e := <-copyWsMsg(conn, req, blocksi):
+ if e != nil {
+ chosenBack.Disable()
+ logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, chosenBack.Name, e))
+ return errors.Join(ErrCopy, e)
+ }
+ case <-ctx.Done():
}
+
+ return nil
}
}
sm := (map[string][]string)(s)
tm := (map[string][]string)(t)
for k, v := range sm {
- tm[k] = v
+ if strings.ToLower(k) == "set-cookie" {
+ cookies := strings.Split(v[0], ";")
+ for k, v := range cookies {
+ if strings.Contains(strings.ToLower(v), "domain=") {
+ cookies = append(cookies[:k], cookies[k+1:]...)
+ break
+ }
+ }
+ tm[k] = []string{strings.Join(cookies, ";")}
+ } else {
+ tm[k] = v
+ }
}
for _, v := range app {
switch v.Action {
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
+ log.Default().Println(resp.StatusCode, resp.Header)
return nil, resp, websocket.ErrBadHandshake
}