"crypto/tls"
"encoding/json"
"fmt"
+ "net/http"
"sync"
"time"
)
type Route struct {
Path string `json:"path"`
Sign string `json:"-"`
+ Splicing int `json:"splicing"`
ErrRedirect bool `json:"errRedirect"`
Back []Back `json:"back"`
}
var backLink []*Back
for i := 0; i < len(t.Back); i++ {
back := &t.Back[i]
+ back.SwapSign()
+ if back.Weight == 0 {
+ continue
+ }
tmpBack := Back{
- Name: back.Name,
- To: back.To,
- Weight: back.Weight,
- ErrBanSec: back.ErrBanSec,
- PathAdd: back.PathAdd,
- ReqHeader: append([]Header{}, back.ReqHeader...),
- ResHeader: append([]Header{}, back.ResHeader...),
+ Name: back.Name,
+ Sign: back.Sign,
+ To: back.To,
+ Weight: back.Weight,
+ ErrBanSec: back.ErrBanSec,
+ PathAdd: back.PathAdd,
+ MatchHeader: append([]Header{}, back.MatchHeader...),
+ ReqHeader: append([]Header{}, back.ReqHeader...),
+ ResHeader: append([]Header{}, back.ResHeader...),
}
for i := 1; i <= back.Weight; i++ {
backLink = append(backLink, &tmpBack)
return backLink
}
+func GetBackByRequest(backs []*Back, r *http.Request) []*Back {
+ var backLink []*Back
+ for i := 0; i < len(backs); i++ {
+ matchs := len(backs[i].MatchHeader) - 1
+ for ; matchs >= 0 &&
+ r.Header.Get(backs[i].MatchHeader[matchs].Key) == backs[i].MatchHeader[matchs].Value; matchs -= 1 {
+ }
+ if matchs == -1 {
+ backLink = append(backLink, backs[i])
+ }
+ }
+ return backLink
+}
+
type Back struct {
- lock sync.RWMutex
- upT time.Time
- Name string `json:"name"`
- To string `json:"to"`
- Weight int `json:"weight"`
- ErrBanSec int `json:"errBanSec"`
- PathAdd bool `json:"pathAdd"`
- ReqHeader []Header `json:"reqHeader"`
- ResHeader []Header `json:"resHeader"`
+ lock sync.RWMutex
+ Sign string `json:"-"`
+ upT time.Time
+ Name string `json:"name"`
+ To string `json:"to"`
+ Weight int `json:"weight"`
+ ErrBanSec int `json:"errBanSec"`
+ PathAdd bool `json:"pathAdd"`
+ MatchHeader []Header `json:"matchHeader"`
+ ReqHeader []Header `json:"reqHeader"`
+ ResHeader []Header `json:"resHeader"`
+}
+
+func (t *Back) SwapSign() bool {
+ data, _ := json.Marshal(t)
+ w := md5.New()
+ w.Write(data)
+ sign := fmt.Sprintf("%x", w.Sum(nil))
+ if t.Sign != sign {
+ t.Sign = sign
+ return true
+ }
+ return false
}
func (t *Back) IsLive() bool {
}
func (t *Back) Disable() {
+ if t.ErrBanSec == 0 {
+ return
+ }
t.lock.Lock()
defer t.lock.Unlock()
t.upT = time.Now().Add(time.Second * time.Duration(t.ErrBanSec))
"fmt"
"io"
"net/http"
- "strings"
"time"
+ _ "unsafe"
"github.com/gorilla/websocket"
pctx "github.com/qydysky/part/ctx"
<-ctx1.Done()
}
+var cookie = fmt.Sprintf("%p", &struct{}{})
+
// 转发
func Run(ctx context.Context, configSP *Config, logger Logger) {
// 根ctx
for i := 0; i < len(configS.Routes); i++ {
route := &configS.Routes[i]
path := route.Path
+ splicing := route.Splicing
ErrRedirect := route.ErrRedirect
if !route.SwapSign() {
continue
}
+ backMap := make(map[string]*Back)
+
+ for i := 0; i < len(backArray); i++ {
+ backMap[backArray[i].Sign] = backArray[i]
+ }
+
logger.Info(`I:`, "路由更新", path)
routeP.Store(path, func(w http.ResponseWriter, r *http.Request) {
ctx1, done1 := pctx.WaitCtx(ctx)
defer done1()
- now := time.Now()
- backI := now.UnixMilli() % int64(len(backArray))
-
- if !backArray[backI].IsLive() {
- for backI = 0; backI < int64(len(backArray)); backI++ {
- if backArray[backI].IsLive() {
- break
+ var backI *Back
+ validHost := validCookieDomain(r.Host)
+ if validHost {
+ if t, e := r.Cookie("_psign_" + cookie); e == nil {
+ if tmp, ok := backMap[t.Value]; ok {
+ backI = tmp
}
}
- if backI == int64(len(backArray)) {
+ }
+
+ if backI == nil {
+ backArray := GetBackByRequest(backArray, r)
+ switch len(backArray) {
+ case 0:
w.WriteHeader(http.StatusServiceUnavailable)
- logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path))
+ logger.Error(`W:`, fmt.Sprintf("%s=> 无匹配", path))
return
+ case 1:
+ backI = backArray[0]
+ default:
+ backI = backArray[time.Now().UnixMilli()%int64(len(backArray))]
+ }
+ if !backI.IsLive() {
+ backI = nil
+ for i := 0; i < len(backArray); i++ {
+ if backArray[i].IsLive() {
+ backI = backArray[i]
+ break
+ }
+ }
+ if backI == nil {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ logger.Error(`E:`, fmt.Sprintf("%s=> 全部后端失效", path))
+ return
+ }
}
}
- logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backArray[backI].Name))
+ if validHost {
+ w.Header().Add("Set-Cookie", (&http.Cookie{
+ Name: "_psign_" + cookie,
+ Value: backI.Sign,
+ MaxAge: splicing,
+ Path: path,
+ }).String())
+ }
+
+ w.Header().Add("_pto_"+cookie, backI.Name)
+
+ logger.Error(`T:`, fmt.Sprintf("%s=>%s", path, backI.Name))
var e error
if r.Header.Get("Upgrade") == "websocket" {
- e = wsDealer(ctx1, w, r, path, backArray[backI], logger)
+ e = wsDealer(ctx1, w, r, path, backI, logger)
} else {
- e = httpDealer(ctx1, w, r, path, backArray[backI], logger)
+ e = httpDealer(ctx1, w, r, path, backI, logger)
}
if e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backArray[backI].Name, e))
+ logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", path, backI.Name, e))
switch e {
case ErrCopy:
- backArray[backI].Disable()
+ backI.Disable()
return
case ErrHeaderCheckFail:
w.WriteHeader(http.StatusForbidden)
return
default:
- backArray[backI].Disable()
+ backI.Disable()
if ErrRedirect {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusTemporaryRedirect)
url += r.URL.String()
}
- if !strings.HasPrefix(url, "http") {
- return ErrNoHttp
+ url = "http" + url
+
+ for _, v := range back.ReqHeader {
+ if v.Action == `check` {
+ if r.Header.Get(v.Key) != v.Value {
+ return ErrHeaderCheckFail
+ }
+ }
}
req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
for _, v := range back.ReqHeader {
switch v.Action {
case `check`:
- if req.Header.Get(v.Key) != v.Value {
- return ErrHeaderCheckFail
- }
case `set`:
req.Header.Set(v.Key, v.Value)
case `add`:
return ErrReqDoFail
}
+ header := w.Header()
for k, v := range resp.Header {
- w.Header().Set(k, v[0])
+ if has(&header, k) {
+ header.Add(k, v[0])
+ } else {
+ header.Set(k, v[0])
+ }
}
for _, v := range back.ResHeader {
return ErrHeaderCheckFail
}
case `set`:
- w.Header().Set(v.Key, v.Value)
+ header.Set(v.Key, v.Value)
case `add`:
- w.Header().Add(v.Key, v.Value)
+ header.Add(v.Key, v.Value)
case `del`:
- w.Header().Del(v.Key)
+ header.Del(v.Key)
default:
logger.Warn(`W:`, fmt.Sprintf("%s=>%s 无效ResHeader %v", routePath, back.Name, v))
}
url += r.URL.String()
}
- if !strings.HasPrefix(url, "ws") {
- return ErrNoWs
- }
+ url = "ws" + url
reqHeader := make(http.Header)
for _, v := range back.ReqHeader {
if res, resp, e := websocket.DefaultDialer.Dial(url, reqHeader); e != nil {
return ErrReqDoFail
} else {
+ resp.Header.Del("connection")
+ resp.Header.Del("upgrade")
+ resp.Header.Del("sec-websocket-accept")
+
for _, v := range back.ResHeader {
switch v.Action {
case `check`:
}
}
}
+
+//go:linkname has net/http.(*Header).has
+func has(h *http.Header, key string) bool
+
+//go:linkname validCookieDomain net/http.validCookieDomain
+func validCookieDomain(v string) bool