]> 127.0.0.1 Git - front/.git/commitdiff
1 v0.1.20240305172524
authorqydysky <qydysky@foxmail.com>
Tue, 5 Mar 2024 17:25:01 +0000 (01:25 +0800)
committerqydysky <qydysky@foxmail.com>
Tue, 5 Mar 2024 17:25:01 +0000 (01:25 +0800)
README.md
config.go
main.go

index 2738cd5623ac2d86874e683d98b3a03835940c6a..6d6ade19f694a907e89ffc3231f5971e17ccb5be 100755 (executable)
--- a/README.md
+++ b/README.md
@@ -7,11 +7,15 @@
 - 自定权重
 - 故障转移
 - 自定义头
+- 请求头过滤
+- 请求数据过滤
 
 支持嵌入到其他项目中/独立运行
 
 配置为json数组格式[],下面为数组中的其中一个{},注意此级不会动态增加/移除
 
+config:
+
 - addr: string 监听端口 例:0.0.0.0:8081
 - matchRule: string 匹配规则 prefix:当未匹配到时,返回最近的/匹配, all:当未匹配到时,返回404
 - copyBlocks: int 转发的块数量,默认1000
     - path: string 路径
     - splicing: int 当客户端支持cookie时,将会固定使用后端多少秒
     - pathAdd: bool 将客户端访问的路径附加在path上 例:/api/req => /ws => /ws/api/req
-    - matchHeader: [] 将会在back前匹配,匹配客户端请求头,只有都匹配才使用此路由, 可以动态增加/删除
-        - key: string 要匹配的header名
-        - matchExp: string 要匹配的正则式
-        - value: string 要匹配的值
-    - reqHeader: [] 将会附加到每个backs前,请求后端时,请求头处理器, 可以动态增加/删除
-        - action: string 可选check、replace、add、del、set。
-        - key: string 具体处理哪个头
-        - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
-        - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
-    - resHeader: [] 将会附加到每个backs前,返回后端的响应时,请求头处理器, 可以动态增加/删除
-        - action: string 可选check、add、del、set。
-        - key: string 具体处理哪个头
-        - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
-        - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+    - matcher 将会附加到每个backs前
     - backs: [] 后端, 可以动态增加/删除
         - name: string 后端名称,将在日志中显示
         - to: string 后端地址,例"s://www.baidu.com",会根据客户端自动添加http or ws在地址前
         - weight: int 权重,按routes中的全部back的权重比分配,当权重变为0时,将停止新请求的进入
         - errBanSec: int 当后端错误时(指连接失败,不指后端错误响应),将会禁用若干秒
-        - matchHeader: [] 匹配客户端请求头,只有都匹配才使用此后端, 可以动态增加/删除
-            - key: string 要匹配的header名
-            - matchExp: string 要匹配的正则式
-            - value: string 要匹配的值
-        - reqHeader: [] 将会附加到每个backs前,请求后端时,请求头处理器, 可以动态增加/删除
-            - action: string 可选check、replace、add、del、set。
-            - key: string 具体处理哪个头
-            - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
-            - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
-        - resHeader: [] 将会附加到每个backs前,返回后端的响应时,请求头处理器, 可以动态增加/删除
-            - action: string 可选check、add、del、set。
-            - key: string 具体处理哪个头
-            - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
-            - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+        - matcher
+
+matcher:
+
+- matchHeader: [] 匹配客户端请求头,只有都匹配才使用此后端, 可以动态增加/删除
+    - key: string 要匹配的header名
+    - matchExp: string 要匹配的正则式
+    - value: string 要匹配的值
+- reqHeader: [] 请求后端前,请求头处理器, 可以动态增加/删除
+    - action: string 可选check、replace、add、del、set。
+    - key: string 具体处理哪个头
+    - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
+    - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+- resHeader: [] 返回后端的响应前,请求头处理器, 可以动态增加/删除
+    - action: string 可选check、add、del、set。
+    - key: string 具体处理哪个头
+    - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
+    - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+- reqBody: [] 请求后端前,请求数据过滤器, 可以动态增加/删除
+    - action: string 可选access,deny。
+    - reqSize:string 限定请求数据大小,默认为"1M"
+    - matchExp: string access时如不匹配将结束请求。deny时如匹配将结束请求。
index 63891506b25acc3988be7ff38ea236bbc9d35a2d..d295606cc7c03829a5542a027f2718625fb619f0 100755 (executable)
--- a/config.go
+++ b/config.go
@@ -1,17 +1,21 @@
 package front
 
 import (
+       "bytes"
        "context"
        "crypto/tls"
        "errors"
        "fmt"
+       "io"
        "math/rand/v2"
        "net/http"
        "regexp"
        "sync"
        "time"
 
+       "github.com/dustin/go-humanize"
        pctx "github.com/qydysky/part/ctx"
+       pio "github.com/qydysky/part/io"
        pslice "github.com/qydysky/part/slice"
        pweb "github.com/qydysky/part/web"
 )
@@ -26,8 +30,9 @@ type Config struct {
        MatchRule  string               `json:"matchRule"`
        CopyBlocks int                  `json:"copyBlocks"`
        BlocksI    pslice.BlocksI[byte] `json:"-"`
-       oldRoutes  []*Route             `json:"-"`
-       Routes     []Route              `json:"routes"`
+
+       routeMap sync.Map `json:"-"`
+       Routes   []Route  `json:"routes"`
 }
 
 func (t *Config) Run(ctx context.Context, logger Logger) {
@@ -65,23 +70,26 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
        syncWeb := pweb.NewSyncMap(&httpSer, &routeP, matchfunc)
        defer syncWeb.Shutdown()
 
-       var addRoute = func(route *Route) {
-               logger.Info(`I:`, "路由加载", route.Path)
+       var addRoute = func(k string, route *Route) {
+               logger.Info(`I:`, "路由加载", k)
+               t.routeMap.Store(k, route)
+
                routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) {
                        ctx1, done1 := pctx.WaitCtx(ctx)
                        defer done1()
 
-                       if !Matched(route.MatchHeader, r) {
+                       if !HeaderMatchs(route.MatchHeader, r) {
                                w.WriteHeader(http.StatusNotFound)
                        }
 
                        var backIs []*Back
                        if t, e := r.Cookie("_psign_" + cookie); e == nil {
-                               if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && Matched(backP.(*Back).MatchHeader, r) {
+                               if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).MatchHeader, r) {
                                        backP.(*Back).PathAdd = route.PathAdd
                                        backP.(*Back).Splicing = route.Splicing
-                                       backP.(*Back).ReqHeader = append(route.ReqHeader, backP.(*Back).ReqHeader...)
-                                       backP.(*Back).ResHeader = append(route.ResHeader, backP.(*Back).ResHeader...)
+                                       backP.(*Back).tmp.ReqHeader = append(route.ReqHeader, backP.(*Back).ReqHeader...)
+                                       backP.(*Back).tmp.ResHeader = append(route.ResHeader, backP.(*Back).ResHeader...)
+                                       backP.(*Back).tmp.ReqBody = append(route.ReqBody, backP.(*Back).ReqBody...)
                                        for i := 0; i < backP.(*Back).Weight; i++ {
                                                backIs = append(backIs, backP.(*Back))
                                        }
@@ -101,16 +109,17 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
                        } else {
                                e = httpDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI)
                        }
-                       if errors.Is(e, ErrHeaderCheckFail) {
+                       if errors.Is(e, ErrHeaderCheckFail) || errors.Is(e, ErrBodyCheckFail) {
                                w.WriteHeader(http.StatusForbidden)
                                return
                        }
                })
        }
 
-       var delRoute = func(route *Route) {
-               logger.Info(`I:`, "路由移除", route.Path)
-               routeP.Store(route.Path, nil)
+       var delRoute = func(k string, route *Route) {
+               logger.Info(`I:`, "路由移除", k)
+               t.routeMap.Delete(k)
+               routeP.Store(k, nil)
        }
 
        t.SwapSign(addRoute, delRoute, logger)
@@ -125,66 +134,58 @@ func (t *Config) Run(ctx context.Context, logger Logger) {
        }
 }
 
-func (t *Config) SwapSign(add func(*Route), del func(*Route), logger Logger) {
-       for i := 0; i < len(t.oldRoutes); i++ {
+func (t *Config) SwapSign(add func(string, *Route), del func(string, *Route), logger Logger) {
+       t.routeMap.Range(func(key, value any) bool {
                var exist bool
                for k := 0; k < len(t.Routes); k++ {
-                       if t.oldRoutes[i].Path == t.Routes[k].Path {
+                       if key.(string) == t.Routes[k].Path {
                                exist = true
                                break
                        }
                }
                if !exist {
-                       del(t.oldRoutes[i])
+                       del(key.(string), value.(*Route))
                }
-       }
+               return true
+       })
 
        for i := 0; i < len(t.Routes); i++ {
-               var exist bool
-               for k := 0; k < len(t.oldRoutes); k++ {
-                       if t.Routes[i].Path == t.oldRoutes[k].Path {
-                               exist = true
-                               break
-                       }
-               }
-               if !exist {
-                       add(&t.Routes[i])
+               if _, ok := t.routeMap.Load(t.Routes[i].Path); !ok {
+                       add(t.Routes[i].Path, &t.Routes[i])
                }
        }
 
-       t.oldRoutes = t.oldRoutes[:0]
-
        for i := 0; i < len(t.Routes); i++ {
                t.Routes[i].SwapSign(
-                       func(b *Back) {
+                       func(k string, b *Back) {
                                logger.Info(`I:`, "后端加载", t.Routes[i].Path, b.Name)
-                               t.Routes[i].backMap.Store(b.Id(), b)
+                               t.Routes[i].backMap.Store(k, b)
                        },
-                       func(b *Back) {
+                       func(k string, b *Back) {
                                logger.Info(`I:`, "后端移除", t.Routes[i].Path, b.Name)
-                               t.Routes[i].backMap.Delete(b.Id())
+                               t.Routes[i].backMap.Delete(k)
                        },
                        logger,
                )
-               t.oldRoutes = append(t.oldRoutes, &t.Routes[i])
        }
 }
 
 type Route struct {
        Path string `json:"path"`
 
-       Splicing    int      `json:"splicing"`
-       PathAdd     bool     `json:"pathAdd"`
-       MatchHeader []Header `json:"matchHeader"`
-       ReqHeader   []Header `json:"reqHeader"`
-       ResHeader   []Header `json:"resHeader"`
+       Splicing int  `json:"splicing"`
+       PathAdd  bool `json:"pathAdd"`
+       Matcher
 
        backMap sync.Map `json:"-"`
        Backs   []Back   `json:"backs"`
 }
 
-func (t *Route) SwapSign(add func(*Back), del func(*Back), logger Logger) {
-       logger.Info(t.Path)
+func (t *Route) Id() string {
+       return fmt.Sprintf("%p", t)
+}
+
+func (t *Route) SwapSign(add func(string, *Back), del func(string, *Back), logger Logger) {
        t.backMap.Range(func(key, value any) bool {
                var exist bool
                for k := 0; k < len(t.Backs); k++ {
@@ -194,14 +195,14 @@ func (t *Route) SwapSign(add func(*Back), del func(*Back), logger Logger) {
                        }
                }
                if !exist {
-                       del(value.(*Back))
+                       del(key.(string), value.(*Back))
                }
                return true
        })
 
        for i := 0; i < len(t.Backs); i++ {
                if _, ok := t.backMap.Load(t.Backs[i].Id()); !ok {
-                       add(&t.Backs[i])
+                       add(t.Backs[i].Id(), &t.Backs[i])
                }
        }
 }
@@ -236,11 +237,12 @@ func (t *Route) SwapSign(add func(*Back), del func(*Back), logger Logger) {
 func (t *Route) FiliterBackByRequest(r *http.Request) []*Back {
        var backLink []*Back
        for i := 0; i < len(t.Backs); i++ {
-               if t.Backs[i].IsLive() && Matched(t.Backs[i].MatchHeader, r) {
+               if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].MatchHeader, r) {
                        t.Backs[i].PathAdd = t.PathAdd
                        t.Backs[i].Splicing = t.Splicing
-                       t.Backs[i].ReqHeader = append(t.ReqHeader, t.Backs[i].ReqHeader...)
-                       t.Backs[i].ResHeader = append(t.ResHeader, t.Backs[i].ResHeader...)
+                       t.Backs[i].tmp.ReqHeader = append(t.ReqHeader, t.Backs[i].ReqHeader...)
+                       t.Backs[i].tmp.ResHeader = append(t.ResHeader, t.Backs[i].ResHeader...)
+                       t.Backs[i].tmp.ReqBody = append(t.ReqBody, t.Backs[i].ReqBody...)
                        for k := 0; k < t.Backs[i].Weight; k++ {
                                backLink = append(backLink, &t.Backs[i])
                        }
@@ -261,37 +263,35 @@ type Back struct {
        Weight    int    `json:"weight"`
        ErrBanSec int    `json:"errBanSec"`
 
-       Splicing    int      `json:"-"`
-       PathAdd     bool     `json:"-"`
-       MatchHeader []Header `json:"matchHeader"`
-       ReqHeader   []Header `json:"reqHeader"`
-       ResHeader   []Header `json:"resHeader"`
+       Splicing int  `json:"-"`
+       PathAdd  bool `json:"-"`
+       Matcher
+       tmp Matcher `json:"-"`
 }
 
 func (t *Back) Id() string {
        return fmt.Sprintf("%p", t)
 }
 
-func Matched(matchHeader []Header, r *http.Request) bool {
+func HeaderMatchs(matchHeader []Header, r *http.Request) bool {
        matchs := len(matchHeader) - 1
        for ; matchs >= 0; matchs -= 1 {
-               if !MatchedOne(matchHeader[matchs], r.Header.Get(matchHeader[matchs].Key)) {
+               if !matchHeader[matchs].Match(r.Header.Get(matchHeader[matchs].Key)) {
                        break
                }
        }
        return matchs == -1
 }
 
-func MatchedOne(matchHeader Header, value string) bool {
-       if matchHeader.Value != "" && value != matchHeader.Value {
-               return false
-       }
-       if matchHeader.MatchExp != "" {
-               if regexp, e := regexp.Compile(matchHeader.MatchExp); e != nil || !regexp.MatchString(value) {
-                       return false
+func BodyMatchs(matchBody []Body, r *http.Request) (reader io.ReadCloser, e error) {
+       reader = r.Body
+       for i := 0; i < len(matchBody); i++ {
+               reader, e = matchBody[i].Match(reader)
+               if e != nil {
+                       return
                }
        }
-       return true
+       return
 }
 
 func (t *Back) IsLive() bool {
@@ -309,9 +309,82 @@ func (t *Back) Disable() {
        t.upT = time.Now().Add(time.Second * time.Duration(t.ErrBanSec))
 }
 
+type Matcher struct {
+       MatchHeader []Header `json:"matchHeader"`
+       ReqHeader   []Header `json:"reqHeader"`
+       ResHeader   []Header `json:"resHeader"`
+       ReqBody     []Body   `json:"reqBody"`
+}
+
 type Header struct {
        Action   string `json:"action"`
        Key      string `json:"key"`
        MatchExp string `json:"matchExp"`
        Value    string `json:"value"`
 }
+
+func (t *Header) Match(value string) bool {
+       if t.Value != "" && value != t.Value {
+               return false
+       }
+       if t.MatchExp != "" {
+               if exp, e := regexp.Compile(t.MatchExp); e != nil || !exp.MatchString(value) {
+                       return false
+               }
+       }
+       return true
+}
+
+type Body struct {
+       Action   string `json:"action"`
+       ReqSize  string `json:"reqSize"`
+       MatchExp string `json:"matchExp"`
+}
+
+func (t *Body) Match(r io.ReadCloser) (d io.ReadCloser, err error) {
+       if exp, e := regexp.Compile(t.MatchExp); e == nil {
+               if t.ReqSize == "" {
+                       t.ReqSize = "1M"
+               }
+
+               var (
+                       size, err = humanize.ParseBytes(t.ReqSize)
+                       buf       = make([]byte, size)
+                       n         int
+               )
+
+               if err != nil {
+                       return nil, err
+               }
+
+               for n < int(size) && err == nil {
+                       var nn int
+                       nn, err = r.Read(buf[n:])
+                       n += nn
+               }
+               if n >= int(size) {
+                       return nil, errors.New("body overflow")
+               } else if err != nil && !errors.Is(err, io.EOF) {
+                       return nil, err
+               }
+               buf = buf[:n]
+
+               switch t.Action {
+               case "access":
+                       if !exp.Match(buf) {
+                               return nil, errors.New("body deny")
+                       }
+               case "deny":
+                       if exp.Match(buf) {
+                               return nil, errors.New("body deny")
+                       }
+               }
+
+               return pio.RWC{
+                       R: bytes.NewReader(buf).Read,
+                       C: func() error { return nil },
+               }, nil
+       } else {
+               return nil, e
+       }
+}
diff --git a/main.go b/main.go
index b3d35fbd6c52bedda48696009c595f9422e41ce2..fba1ca6a24abc61b8aaf1b9e2b63904d52a80ef6 100755 (executable)
--- a/main.go
+++ b/main.go
@@ -116,6 +116,7 @@ var (
        ErrReqDoFail       = errors.New("ErrReqDoFail")
        ErrResDoFail       = errors.New("ErrResDoFail")
        ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail")
+       ErrBodyCheckFail   = errors.New("ErrBodyCheckFail")
 )
 
 func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
@@ -136,7 +137,7 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
 
                url = "http" + url
 
-               for _, v := range chosenBack.ReqHeader {
+               for _, v := range chosenBack.tmp.ReqHeader {
                        if v.Action == `check` {
                                if r.Header.Get(v.Key) != v.Value {
                                        return ErrHeaderCheckFail
@@ -144,12 +145,18 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
                        }
                }
 
-               req, e := http.NewRequestWithContext(ctx, r.Method, url, r.Body)
+               reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
+               if e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail))
+                       return errors.Join(ErrBodyCheckFail, e)
+               }
+
+               req, e := http.NewRequestWithContext(ctx, r.Method, url, reader)
                if e != nil {
                        return errors.Join(ErrReqCreFail, e)
                }
 
-               if e := copyHeader(r.Header, req.Header, chosenBack.ReqHeader); e != nil {
+               if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil {
                        logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
                        return e
                }
@@ -190,7 +197,7 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou
 
        w.Header().Add("_pto_"+cookie, chosenBack.Name)
 
-       if e := copyHeader(resp.Header, w.Header(), chosenBack.ResHeader); e != nil {
+       if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil {
                logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
                return e
        }
@@ -229,6 +236,12 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
                chosenBack = backs[0]
                backs = backs[1:]
 
+               _, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
+               if e != nil {
+                       logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail))
+                       return errors.Join(ErrBodyCheckFail, e)
+               }
+
                url := chosenBack.To
                if chosenBack.PathAdd {
                        url += r.URL.String()
@@ -238,12 +251,11 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
 
                reqHeader := make(http.Header)
 
-               if e := copyHeader(r.Header, reqHeader, chosenBack.ReqHeader); e != nil {
+               if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil {
                        logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
                        return e
                }
 
-               var e error
                conn, resp, e = DialContext(ctx, url, reqHeader)
                if e != nil {
                        chosenBack.Disable()
@@ -278,7 +290,7 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route
        defer conn.Close()
 
        resHeader := make(http.Header)
-       if e := copyHeader(resp.Header, resHeader, chosenBack.ResHeader); e != nil {
+       if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil {
                logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
                return e
        }
@@ -328,7 +340,7 @@ func copyHeader(s, t http.Header, app []Header) error {
        for _, v := range app {
                switch v.Action {
                case `check`:
-                       if !MatchedOne(v, tm[v.Key][0]) {
+                       if !v.Match(tm[v.Key][0]) {
                                return ErrHeaderCheckFail
                        }
                case `replace`: