From 8705b32ebe0de7ec15df08068e722c0e46ead36e Mon Sep 17 00:00:00 2001 From: qydysky Date: Sat, 23 Mar 2024 10:59:52 +0800 Subject: [PATCH] 1 --- README.md | 4 ++++ config.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++------- http.go | 49 +++++++++++++++++++++++++++++++------- main.go | 2 ++ ws.go | 44 +++++++++++++++++++++++++++------- 5 files changed, 143 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 56164a3..b15fc4c 100755 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ config: - *tls*: {} 启用tls - *pub*: string 公钥pem路径 - *key*: string 私钥pem路径 +- version: string 配置版本,当变化时,将重载 - routes: [] 路由 - path: string 路径 - pathAdd: bool 将客户端访问的路径附加在path上 例:/api/req => /ws => /ws/api/req @@ -43,6 +44,7 @@ config: - name: string 后端名称,将在日志中显示 - to: string 后端地址,例`s://www.baidu.com`,会根据客户端自动添加http or ws在地址前 - weight: int 权重,按routes中的全部back的权重比分配,当权重为0时,将停止新请求的进入 + - alwaysUp: bool 总是在线 - setting... setting: @@ -50,6 +52,8 @@ setting: - splicing: int 当客户端支持cookie时,将会固定使用后端多少秒,默认不启用 - errToSec: float64 当后端响应超过(ws则指初次返回时间)指定秒,将会触发errBanSec - errBanSec: int 当后端错误时(指连接失败,不指后端错误响应),将会禁用若干秒 +- insecureSkipVerify: bool 忽略不安全的tls证书 +- verifyPeerCer: string 路径,校验服务器证书,使用intermediate_ca - filiter: - reqUri: 请求后端前,请求路径过滤器 diff --git a/config.go b/config.go index 4abf61a..e2e623a 100755 --- a/config.go +++ b/config.go @@ -3,10 +3,12 @@ package front import ( "context" "crypto/tls" + "encoding/pem" "errors" "fmt" "net" "net/http" + "os" "strings" "sync" "time" @@ -211,9 +213,9 @@ func (t *Config) SwapSign(ctx context.Context, logger Logger) { for i := 0; i < len(t.Routes); i++ { if _, ok := t.routeMap.Load(t.Routes[i].Path); !ok { - routeU(&t.Routes[i], logger) add(t.Routes[i].Path, &t.Routes[i], logger) } + routeU(&t.Routes[i], logger) } } @@ -253,6 +255,7 @@ func (t *Route) SwapSign(add func(string, *Back), del func(string, *Back), logge if _, ok := t.backMap.Load(t.Backs[i].Id()); !ok { add(t.Backs[i].Id(), &t.Backs[i]) } + t.Backs[i].SwapSign(logger) } } @@ -286,13 +289,27 @@ type Back struct { lastResDru time.Duration `json:"-"` resDru time.Duration `json:"-"` - Name string `json:"name"` - To string `json:"to"` - Weight int `json:"weight"` + Name string `json:"name"` + To string `json:"to"` + Weight int `json:"weight"` + AlwaysUp bool `json:"alwaysUp"` Setting } +func (t *Back) SwapSign(logger Logger) { + path := t.VerifyPeerCer + if path == "" { + path = t.route.VerifyPeerCer + } + if path == "" { + t.verifyPeerCerErr = ErrEmptyVerifyPeerCerByte + t.verifyPeerCer = nil + } else { + t.verifyPeerCer, t.verifyPeerCerErr = os.ReadFile(path) + } +} + func (t *Back) Splicing() int { return t.route.Splicing } @@ -313,6 +330,12 @@ func (t *Back) getErrToSec() float64 { return t.ErrToSec } } +func (t *Back) getInsecureSkipVerify() bool { + return t.route.InsecureSkipVerify || t.InsecureSkipVerify +} +func (t *Back) getVerifyPeerCer() (cer []byte, e error) { + return t.verifyPeerCer, t.verifyPeerCerErr +} func (t *Back) getFiliterReqHeader() *filiter.Header { if !t.Filiter.ReqHeader.Valid() { return &t.route.Filiter.ReqHeader @@ -361,12 +384,18 @@ func (t *Back) ed() { } func (t *Back) IsLive() bool { + if t.AlwaysUp { + return true + } t.lock.RLock() defer t.lock.RUnlock() return t.upT.Before(time.Now()) } func (t *Back) Disable() { + if t.AlwaysUp { + return + } tmp := t.getErrBanSec() if tmp == 0 { tmp = 1 @@ -378,9 +407,32 @@ func (t *Back) Disable() { } type Setting struct { - ErrToSec float64 `json:"errToSec"` - Splicing int `json:"splicing"` - ErrBanSec int `json:"errBanSec"` - Filiter filiter.Filiter `json:"filiter"` - Dealer dealer.Dealer `json:"dealer"` + ErrToSec float64 `json:"errToSec"` + Splicing int `json:"splicing"` + ErrBanSec int `json:"errBanSec"` + InsecureSkipVerify bool `json:"insecureSkipVerify"` + VerifyPeerCer string `json:"verifyPeerCer"` + Filiter filiter.Filiter `json:"filiter"` + Dealer dealer.Dealer `json:"dealer"` + verifyPeerCer []byte + verifyPeerCerErr error +} + +var ( + ErrEmptyVerifyPeerCerByte = errors.New("ErrEmptyVerifyPeerCerByte") +) + +func LoadX509PubKey(certPEMBlock []byte) tls.Certificate { + var cert tls.Certificate + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } + } + return cert } diff --git a/http.go b/http.go index b2c3dd8..39b398d 100644 --- a/http.go +++ b/http.go @@ -2,6 +2,8 @@ package front import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -21,13 +23,11 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou logFormat = "%v%v > %v http %v %v %v" ) - for 0 < len(backs) && resp == nil { - chosenBack = backs[0] - backs = backs[1:] - - if !chosenBack.IsLive() { + for i := 0; i < len(backs) && resp == nil; i++ { + if !backs[i].IsLive() { continue } + chosenBack = backs[i] url := chosenBack.To if chosenBack.PathAdd() { @@ -46,14 +46,44 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou return ErrDealReqHeader } + customTransport := http.DefaultTransport.(*http.Transport).Clone() + + customTransport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: chosenBack.getInsecureSkipVerify(), + } + + if cer, err := chosenBack.getVerifyPeerCer(); err == nil { + pool := x509.NewCertPool() + if pool.AppendCertsFromPEM(cer) { + customTransport.TLSClientConfig.InsecureSkipVerify = true + customTransport.TLSClientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) (e error) { + if len(rawCerts) == 0 { + return ErrCerVerify + } + if serCer, err := x509.ParseCertificate(rawCerts[0]); err != nil { + return err + } else if _, err = serCer.Verify(x509.VerifyOptions{Intermediates: pool, Roots: pool}); err != nil { + return err + } + return + } + } else { + logger.Warn(`W:`, fmt.Sprintf(logFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, "Err", ErrCerVerify, time.Since(opT))) + } + } else if err != ErrEmptyVerifyPeerCerByte { + logger.Warn(`W:`, fmt.Sprintf(logFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, "Err", err, time.Since(opT))) + } + client := http.Client{ + Transport: customTransport, CheckRedirect: func(req *http.Request, via []*http.Request) error { return ErrRedirect }, } + resp, e = client.Do(req) if e != nil && !errors.Is(e, ErrRedirect) && !errors.Is(e, context.Canceled) { - logger.Warn(`W:`, fmt.Sprintf(logFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, "BLOCK", e, time.Since(opT))) + logger.Warn(`W:`, fmt.Sprintf(logFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, "Err", e, time.Since(opT))) chosenBack.Disable() resp = nil } @@ -65,11 +95,14 @@ func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, rou } } - if resp == nil { - logger.Warn(`W:`, fmt.Sprintf(logFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, "BLOCK", ErrAllBacksFail, time.Since(opT))) + if chosenBack == nil { return ErrAllBacksFail } + if resp == nil { + return ErrResFail + } + if ok, e := chosenBack.getFiliterResHeader().Match(resp.Header); e != nil { logger.Warn(`W:`, fmt.Sprintf(logFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, "Err", e, time.Since(opT))) } else if !ok { diff --git a/main.go b/main.go index cacde3d..f83a213 100755 --- a/main.go +++ b/main.go @@ -175,6 +175,7 @@ var ( ErrReqCreFail = errors.New("ErrReqCreFail") ErrReqDoFail = errors.New("ErrReqDoFail") ErrResDoFail = errors.New("ErrResDoFail") + ErrResFail = errors.New("ErrResFail") ErrResTO = errors.New("ErrResTO") ErrUriTooLong = errors.New("ErrUriTooLong") ErrPatherCheckFail = errors.New("ErrPatherCheckFail") @@ -185,4 +186,5 @@ var ( ErrNoRoute = errors.New("ErrNoRoute") ErrDealReqHeader = errors.New("ErrDealReqHeader") ErrDealResHeader = errors.New("ErrDealResHeader") + ErrCerVerify = errors.New("ErrCerVerify") ) diff --git a/ws.go b/ws.go index 3177068..4c918f7 100644 --- a/ws.go +++ b/ws.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -32,13 +33,11 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route errFormat = "%v > %v > %v ws %v %v" ) - for 0 < len(backs) && (resp == nil || conn == nil) { - chosenBack = backs[0] - backs = backs[1:] - - if !chosenBack.IsLive() { + for i := 0; i < len(backs) && (resp == nil || conn == nil); i++ { + if !backs[i].IsLive() { continue } + chosenBack = backs[i] url := chosenBack.To if chosenBack.PathAdd() { @@ -54,7 +53,7 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route return ErrDealReqHeader } - conn, resp, e = DialContext(ctx, url, reqHeader) + conn, resp, e = DialContext(ctx, url, reqHeader, chosenBack) if e != nil && !errors.Is(e, context.Canceled) { logger.Warn(`W:`, fmt.Sprintf(errFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) chosenBack.Disable() @@ -71,11 +70,14 @@ func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, route } } - if resp == nil || conn == nil { - logger.Warn(`W:`, fmt.Sprintf(errFormat, chosenBack.route.config.Addr, routePath, chosenBack.Name, ErrBackFail, time.Since(opT))) + if chosenBack == nil { return ErrAllBacksFail } + if resp == nil || conn == nil { + return ErrResFail + } + if pctx.Done(r.Context()) { return context.Canceled } @@ -162,7 +164,7 @@ func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-cha return c } -func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) { +func DialContext(ctx context.Context, urlStr string, requestHeader http.Header, chosenBack *Back) (net.Conn, *http.Response, error) { d := websocket.DefaultDialer challengeKey := requestHeader.Get("Sec-WebSocket-Key") @@ -303,6 +305,30 @@ func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) if cfg.ServerName == "" { cfg.ServerName = hostNoPort } + cfg.InsecureSkipVerify = chosenBack.getInsecureSkipVerify() + + if cer, err := chosenBack.getVerifyPeerCer(); err == nil { + pool := x509.NewCertPool() + if pool.AppendCertsFromPEM(cer) { + cfg.InsecureSkipVerify = true + cfg.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) (e error) { + if len(rawCerts) == 0 { + return ErrCerVerify + } + if serCer, err := x509.ParseCertificate(rawCerts[0]); err != nil { + return err + } else if _, err = serCer.Verify(x509.VerifyOptions{Intermediates: pool, Roots: pool}); err != nil { + return err + } + return + } + } else { + return nil, nil, ErrCerVerify + } + } else if err != ErrEmptyVerifyPeerCerByte { + return nil, nil, err + } + tlsConn := tls.Client(netConn, cfg) netConn = tlsConn -- 2.39.2