From 584cbc3f0b068eb3c31efee6c95e82510290c0e5 Mon Sep 17 00:00:00 2001 From: qydysky Date: Sun, 10 Mar 2024 02:11:40 +0800 Subject: [PATCH] 1 --- README.md | 18 +- config.go | 119 ++++----- go.mod | 12 +- go.sum | 12 +- http.go | 111 ++++++++ main.go | 701 ++++----------------------------------------------- main/main.go | 2 +- ws.go | 525 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 779 insertions(+), 721 deletions(-) create mode 100644 http.go create mode 100644 ws.go diff --git a/README.md b/README.md index 6d6ade1..caa8acd 100755 --- a/README.md +++ b/README.md @@ -36,21 +36,17 @@ config: matcher: -- matchHeader: [] 匹配客户端请求头,只有都匹配才使用此后端, 可以动态增加/删除 - - key: string 要匹配的header名 - - matchExp: string 要匹配的正则式 - - value: string 要匹配的值 - reqHeader: [] 请求后端前,请求头处理器, 可以动态增加/删除 - - action: string 可选check、replace、add、del、set。 + - action: string 可选access、deny、replace、add、del、set。 - key: string 具体处理哪个头 - - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换 - - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。 + - matchExp: string access时不匹配将结束请求。deny时匹配将结束请求。replace时结合value进行替换 + - value: string replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。 - resHeader: [] 返回后端的响应前,请求头处理器, 可以动态增加/删除 - - action: string 可选check、add、del、set。 + - action: string 可选access、deny、add、del、set。 - key: string 具体处理哪个头 - - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换 - - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。 + - matchExp: string access时不匹配将结束请求。deny时匹配将结束请求。replace时结合value进行替换 + - value: string replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。 - reqBody: [] 请求后端前,请求数据过滤器, 可以动态增加/删除 - - action: string 可选access,deny。 + - action: string 可选access、deny。 - reqSize:string 限定请求数据大小,默认为"1M" - matchExp: string access时如不匹配将结束请求。deny时如匹配将结束请求。 diff --git a/config.go b/config.go index d295606..460538d 100755 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "math/rand/v2" + "net" "net/http" "regexp" "sync" @@ -31,6 +32,7 @@ type Config struct { CopyBlocks int `json:"copyBlocks"` BlocksI pslice.BlocksI[byte] `json:"-"` + routeP pweb.WebPath routeMap sync.Map `json:"-"` Routes []Route `json:"routes"` } @@ -39,20 +41,21 @@ func (t *Config) Run(ctx context.Context, logger Logger) { ctx, done := pctx.WithWait(ctx, 0, time.Minute) defer done() - routeP := pweb.WebPath{} - var matchfunc func(path string) (func(w http.ResponseWriter, r *http.Request), bool) switch t.MatchRule { case "all": - matchfunc = routeP.Load + matchfunc = t.routeP.Load default: - matchfunc = routeP.LoadPerfix + matchfunc = t.routeP.LoadPerfix } - httpSer := http.Server{Addr: t.Addr} + httpSer := http.Server{ + Addr: t.Addr, + BaseContext: func(l net.Listener) context.Context { return ctx }, + } if t.TLS.Key != "" && t.TLS.Pub != "" { if cert, e := tls.LoadX509KeyPair(t.TLS.Pub, t.TLS.Key); e != nil { - logger.Error(`E:`, e) + logger.Error(`E:`, fmt.Sprintf("%v %v", t.Addr, e)) } else { httpSer.TLSConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -67,24 +70,29 @@ func (t *Config) Run(ctx context.Context, logger Logger) { t.BlocksI = pslice.NewBlocks[byte](16*1024, t.CopyBlocks) } - syncWeb := pweb.NewSyncMap(&httpSer, &routeP, matchfunc) + syncWeb := pweb.NewSyncMap(&httpSer, &t.routeP, matchfunc) defer syncWeb.Shutdown() - var addRoute = func(k string, route *Route) { - logger.Info(`I:`, "路由加载", k) - t.routeMap.Store(k, route) + t.SwapSign(ctx, logger) + logger.Info(`I:`, fmt.Sprintf("%v running", t.Addr)) + <-ctx.Done() + logger.Info(`I:`, fmt.Sprintf("%v shutdown", t.Addr)) +} - routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) { - ctx1, done1 := pctx.WaitCtx(ctx) - defer done1() +func (t *Config) SwapSign(ctx context.Context, logger Logger) { + var add = func(k string, route *Route, logger Logger) { + route.config = t + logger.Info(`I:`, fmt.Sprintf("%v > %v", t.Addr, k)) + t.routeMap.Store(k, route) - if !HeaderMatchs(route.MatchHeader, r) { + t.routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) { + if !HeaderMatchs(route.ReqHeader, r) { w.WriteHeader(http.StatusNotFound) } var backIs []*Back if t, e := r.Cookie("_psign_" + cookie); e == nil { - if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).MatchHeader, r) { + if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).ReqHeader, r) { backP.(*Back).PathAdd = route.PathAdd backP.(*Back).Splicing = route.Splicing backP.(*Back).tmp.ReqHeader = append(route.ReqHeader, backP.(*Back).ReqHeader...) @@ -105,9 +113,12 @@ func (t *Config) Run(ctx context.Context, logger Logger) { var e error if r.Header.Get("Upgrade") == "websocket" { - e = wsDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI) + e = wsDealer(r.Context(), w, r, route.Path, backIs, logger, t.BlocksI) } else { - e = httpDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI) + e = httpDealer(r.Context(), w, r, route.Path, backIs, logger, t.BlocksI) + } + if e != nil { + w.Header().Add(header+"Error", e.Error()) } if errors.Is(e, ErrHeaderCheckFail) || errors.Is(e, ErrBodyCheckFail) { w.WriteHeader(http.StatusForbidden) @@ -116,25 +127,27 @@ func (t *Config) Run(ctx context.Context, logger Logger) { }) } - var delRoute = func(k string, route *Route) { - logger.Info(`I:`, "路由移除", k) + var del = func(k string, route *Route, logger Logger) { + logger.Info(`I:`, fmt.Sprintf("%v x %v", t.Addr, k)) t.routeMap.Delete(k) - routeP.Store(k, nil) + t.routeP.Store(k, nil) } - t.SwapSign(addRoute, delRoute, logger) - logger.Info(`I:`, "启动完成") - for { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second * 10): - t.SwapSign(addRoute, delRoute, logger) - } + var routeU = func(route *Route, logger Logger) { + route.SwapSign( + func(k string, b *Back) { + b.route = route + logger.Info(`I:`, fmt.Sprintf("%v > %v > %v", t.Addr, route.Path, b.Name)) + route.backMap.Store(k, b) + }, + func(k string, b *Back) { + logger.Info(`I:`, fmt.Sprintf("%v > %v x %v", t.Addr, route.Path, b.Name)) + route.backMap.Delete(k) + }, + logger, + ) } -} -func (t *Config) SwapSign(add func(string, *Route), del func(string, *Route), logger Logger) { t.routeMap.Range(func(key, value any) bool { var exist bool for k := 0; k < len(t.Routes); k++ { @@ -144,34 +157,22 @@ func (t *Config) SwapSign(add func(string, *Route), del func(string, *Route), lo } } if !exist { - del(key.(string), value.(*Route)) + del(key.(string), value.(*Route), logger) } return true }) for i := 0; i < len(t.Routes); i++ { if _, ok := t.routeMap.Load(t.Routes[i].Path); !ok { - add(t.Routes[i].Path, &t.Routes[i]) + routeU(&t.Routes[i], logger) + add(t.Routes[i].Path, &t.Routes[i], logger) } } - - for i := 0; i < len(t.Routes); i++ { - t.Routes[i].SwapSign( - func(k string, b *Back) { - logger.Info(`I:`, "后端加载", t.Routes[i].Path, b.Name) - t.Routes[i].backMap.Store(k, b) - }, - func(k string, b *Back) { - logger.Info(`I:`, "后端移除", t.Routes[i].Path, b.Name) - t.Routes[i].backMap.Delete(k) - }, - logger, - ) - } } type Route struct { - Path string `json:"path"` + config *Config `json:"-"` + Path string `json:"path"` Splicing int `json:"splicing"` PathAdd bool `json:"pathAdd"` @@ -237,7 +238,7 @@ func (t *Route) SwapSign(add func(string, *Back), del func(string, *Back), logge func (t *Route) FiliterBackByRequest(r *http.Request) []*Back { var backLink []*Back for i := 0; i < len(t.Backs); i++ { - if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].MatchHeader, r) { + if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].ReqHeader, r) { t.Backs[i].PathAdd = t.PathAdd t.Backs[i].Splicing = t.Splicing t.Backs[i].tmp.ReqHeader = append(t.ReqHeader, t.Backs[i].ReqHeader...) @@ -255,8 +256,9 @@ func (t *Route) FiliterBackByRequest(r *http.Request) []*Back { } type Back struct { - lock sync.RWMutex `json:"-"` - upT time.Time `json:"-"` + route *Route `json:"-"` + lock sync.RWMutex `json:"-"` + upT time.Time `json:"-"` Name string `json:"name"` To string `json:"to"` @@ -310,10 +312,9 @@ func (t *Back) Disable() { } type Matcher struct { - MatchHeader []Header `json:"matchHeader"` - ReqHeader []Header `json:"reqHeader"` - ResHeader []Header `json:"resHeader"` - ReqBody []Body `json:"reqBody"` + ReqHeader []Header `json:"reqHeader"` + ResHeader []Header `json:"resHeader"` + ReqBody []Body `json:"reqBody"` } type Header struct { @@ -324,15 +325,15 @@ type Header struct { } func (t *Header) Match(value string) bool { - if t.Value != "" && value != t.Value { - return false + if t.Action != "access" && t.Action != "deny" { + return true } if t.MatchExp != "" { if exp, e := regexp.Compile(t.MatchExp); e != nil || !exp.MatchString(value) { - return false + return t.Action == "deny" } } - return true + return t.Action == "access" } type Body struct { diff --git a/go.mod b/go.mod index 46c2258..77a3d61 100755 --- a/go.mod +++ b/go.mod @@ -1,24 +1,30 @@ module github.com/qydysky/front -go 1.21 +go 1.22 require ( + github.com/dustin/go-humanize v1.0.1 github.com/gorilla/websocket v1.5.1 - github.com/qydysky/part v0.28.20240114140844 + github.com/qydysky/part v0.28.20240309172046 golang.org/x/net v0.18.0 ) require ( + github.com/andybalholm/brotli v1.0.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/google/uuid v1.4.0 // indirect + github.com/klauspost/compress v1.17.3 // indirect + github.com/miekg/dns v1.1.57 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/shirou/gopsutil v3.21.11+incompatible // indirect + github.com/thedevsaddam/gojsonq/v2 v2.5.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect + golang.org/x/mod v0.14.0 // indirect golang.org/x/sys v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.15.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 50a87c8..1c27d4e 100755 --- a/go.sum +++ b/go.sum @@ -13,20 +13,28 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= +github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM= +github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/qydysky/part v0.28.20240114140844 h1:9d0NPLx2CHKFca5Q6brECXeCjj0VXXaXmZRcJo6W85w= -github.com/qydysky/part v0.28.20240114140844/go.mod h1:NyKyjpBCSjcHtKlC+fL5lCidm57UCnwEgufiBDs5yxA= +github.com/qydysky/part v0.28.20240309114649 h1:b82WHpgNecfv/UX4makU9EIHJ4nLAP9dnXCb8S+sJt4= +github.com/qydysky/part v0.28.20240309114649/go.mod h1:8Y4MrasGC0BLEM71QY/MuP2jl+v5b0Y+rqox3qJu97c= +github.com/qydysky/part v0.28.20240309172046 h1:aw2Dv8VaP0p+IMkwJQlCNaz0ccJ6l8YUhu+y39kvQgU= +github.com/qydysky/part v0.28.20240309172046/go.mod h1:8Y4MrasGC0BLEM71QY/MuP2jl+v5b0Y+rqox3qJu97c= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/thedevsaddam/gojsonq/v2 v2.5.2 h1:CoMVaYyKFsVj6TjU6APqAhAvC07hTI6IQen8PHzHYY0= +github.com/thedevsaddam/gojsonq/v2 v2.5.2/go.mod h1:bv6Xa7kWy82uT0LnXPE2SzGqTj33TAEeR560MdJkiXs= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/http.go b/http.go new file mode 100644 index 0000000..51b1f58 --- /dev/null +++ b/http.go @@ -0,0 +1,111 @@ +package front + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "time" + _ "unsafe" + + pslice "github.com/qydysky/part/slice" +) + +func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error { + + var ( + opT = time.Now() + resp *http.Response + chosenBack *Back + ) + + for 0 < len(backs) && resp == nil { + chosenBack = backs[0] + backs = backs[1:] + + url := chosenBack.To + if chosenBack.PathAdd { + url += r.URL.String() + } + + url = "http" + url + + reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r) + if e != nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return errors.Join(ErrBodyCheckFail, e) + } + + req, e := http.NewRequestWithContext(ctx, r.Method, url, reader) + if e != nil { + return errors.Join(ErrReqCreFail, e) + } + + if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return e + } + + client := http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return ErrRedirect + }, + } + resp, e = client.Do(req) + if e != nil && !errors.Is(e, ErrRedirect) && !errors.Is(e, context.Canceled) { + chosenBack.Disable() + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + } + } + + if 0 == len(backs) && resp == nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http 全部后端故障 %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT))) + return errors.New("全部后端故障") + } else if resp == nil { + return errors.New("后端故障") + } + + logger.Debug(`T:`, fmt.Sprintf("%v > %v > %v http ok %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT))) + + { + cookie := &http.Cookie{ + Name: "_psign_" + cookie, + Value: chosenBack.Id(), + MaxAge: chosenBack.Splicing, + Path: "/", + } + if validCookieDomain(r.Host) { + cookie.Domain = r.Host + } + w.Header().Add("Set-Cookie", (cookie).String()) + } + + w.Header().Add(header+"Info", cookie+";"+chosenBack.Name) + + if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return e + } + + w.WriteHeader(resp.StatusCode) + + if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 { + return nil + } + + defer resp.Body.Close() + if tmpbuf, put, e := blocksi.Get(); e != nil { + logger.Error(`E:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + chosenBack.Disable() + return errors.Join(ErrCopy, e) + } else { + defer put() + if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil { + logger.Error(`E:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + chosenBack.Disable() + return errors.Join(ErrCopy, e) + } + } + return nil +} diff --git a/main.go b/main.go index fba1ca6..5f0e1d0 100755 --- a/main.go +++ b/main.go @@ -1,29 +1,20 @@ package front import ( - "bufio" - "bytes" "context" - "crypto/tls" + "crypto/md5" "encoding/json" "errors" "fmt" "io" - "log" - "net" "net/http" - "net/http/httptrace" - "net/url" "regexp" "strings" "time" _ "unsafe" - "github.com/gorilla/websocket" pctx "github.com/qydysky/part/ctx" - pslice "github.com/qydysky/part/slice" pweb "github.com/qydysky/part/web" - "golang.org/x/net/proxy" ) type Logger interface { @@ -42,10 +33,17 @@ func validCookieDomain(v string) bool // 加载 func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error { - if e := loadConfig(buf, configF, configS); e != nil { + var oldBufMd5 string + + if bufMd5, e := loadConfig(ctx, buf, configF, configS, logger); e != nil { logger.Error(`E:`, "配置加载", e) return e + } else { + oldBufMd5 = bufMd5 } + + logger.Info(`I:`, "配置更新", oldBufMd5[:5]) + // 定时加载config go func() { ctx1, done1 := pctx.WaitCtx(ctx) @@ -53,8 +51,11 @@ func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config for { select { case <-time.After(time.Second * 5): - if e := loadConfig(buf, configF, configS); e != nil { + if bufMd5, e := loadConfig(ctx, buf, configF, configS, logger); e != nil { logger.Error(`E:`, "配置加载", e) + } else if bufMd5 != oldBufMd5 { + oldBufMd5 = bufMd5 + logger.Info(`I:`, "配置更新", oldBufMd5[:5]) } case <-ctx1.Done(): return @@ -85,239 +86,32 @@ func Test(ctx context.Context, port int, logger Logger) { <-ctx1.Done() } -var cookie = fmt.Sprintf("%p", &struct{}{}) - -func loadConfig(buf []byte, configF File, configS *[]Config) error { +func loadConfig(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) (md5k string, e error) { + defer func() { + if err := recover(); err != nil { + logger.Error(`E:`, err) + e = errors.New("read panic") + } + }() if i, e := configF.Read(buf); e != nil && !errors.Is(e, io.EOF) { - return e + return "", e } else if i == cap(buf) { - return errors.New(`buf full`) + return "", errors.New(`buf full`) } else { - for i := 0; i < len(*configS); i++ { - (*configS)[i].lock.Lock() - defer (*configS)[i].lock.Unlock() - } - if e := json.Unmarshal(buf[:i], configS); e != nil { - return e - } - } - return nil -} + w := md5.New() + w.Write(buf[:i]) + md5k = fmt.Sprintf("%x", w.Sum(nil)) -//go:linkname nanotime1 runtime.nanotime1 -func nanotime1() int64 - -var ( - ErrRedirect = errors.New("ErrRedirect") - ErrNoHttp = errors.New("ErrNoHttp") - ErrNoWs = errors.New("ErrNoWs") - ErrCopy = errors.New("ErrCopy") - ErrReqCreFail = errors.New("ErrReqCreFail") - ErrReqDoFail = errors.New("ErrReqDoFail") - ErrResDoFail = errors.New("ErrResDoFail") - ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail") - ErrBodyCheckFail = errors.New("ErrBodyCheckFail") -) - -func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error { - var ( - opT = time.Now() - resp *http.Response - chosenBack *Back - ) - - for 0 < len(backs) && resp == nil { - chosenBack = backs[0] - backs = backs[1:] - - url := chosenBack.To - if chosenBack.PathAdd { - url += r.URL.String() - } - - url = "http" + url - - for _, v := range chosenBack.tmp.ReqHeader { - if v.Action == `check` { - if r.Header.Get(v.Key) != v.Value { - return ErrHeaderCheckFail - } - } - } - - reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r) - if e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail)) - return errors.Join(ErrBodyCheckFail, e) - } - - req, e := http.NewRequestWithContext(ctx, r.Method, url, reader) - if e != nil { - return errors.Join(ErrReqCreFail, e) - } - - if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - return e - } - - client := http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return ErrRedirect - }, - } - resp, e = client.Do(req) - if e != nil && !errors.Is(e, ErrRedirect) { - chosenBack.Disable() - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - } - } - - if 0 == len(backs) && resp == nil { - logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name)) - return errors.New("全部后端故障") - } else if resp == nil { - return errors.New("后端故障") - } - - logger.Debug(`T:`, fmt.Sprintf("http %s=>%s %v", routePath, chosenBack.Name, time.Since(opT))) - - { - cookie := &http.Cookie{ - Name: "_psign_" + cookie, - Value: chosenBack.Id(), - MaxAge: chosenBack.Splicing, - Path: "/", - } - if validCookieDomain(r.Host) { - cookie.Domain = r.Host - } - w.Header().Add("Set-Cookie", (cookie).String()) - } - - w.Header().Add("_pto_"+cookie, chosenBack.Name) - - if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - return e - } - - w.WriteHeader(resp.StatusCode) - - if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 { - return nil - } - - defer resp.Body.Close() - if tmpbuf, put, e := blocksi.Get(); e != nil { - logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - chosenBack.Disable() - return errors.Join(ErrCopy, e) - } else { - defer put() - if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil { - logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - chosenBack.Disable() - return errors.Join(ErrCopy, e) - } - } - return nil -} - -func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error { - var ( - opT = time.Now() - resp *http.Response - conn net.Conn - chosenBack *Back - ) - - for 0 < len(backs) && (resp == nil || conn == nil) { - chosenBack = backs[0] - backs = backs[1:] - - _, e := BodyMatchs(chosenBack.tmp.ReqBody, r) - if e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail)) - return errors.Join(ErrBodyCheckFail, e) - } - - url := chosenBack.To - if chosenBack.PathAdd { - url += r.URL.String() - } - - url = "ws" + url - - reqHeader := make(http.Header) - - if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - return e - } - - conn, resp, e = DialContext(ctx, url, reqHeader) - if e != nil { - chosenBack.Disable() - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - } - } - - if 0 == len(backs) && (resp == nil || conn == nil) { - logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name)) - return errors.New("全部后端故障") - } else if resp == nil || conn == nil { - return errors.New("后端故障") - } - - logger.Debug(`T:`, fmt.Sprintf("ws %s=>%s %v", routePath, chosenBack.Name, time.Since(opT))) - - { - cookie := &http.Cookie{ - Name: "_psign_" + cookie, - Value: chosenBack.Id(), - MaxAge: chosenBack.Splicing, - Path: "/", - } - if validCookieDomain(r.Host) { - cookie.Domain = r.Host + if e := json.Unmarshal(buf[:i], configS); e != nil { + return md5k, e } - w.Header().Add("Set-Cookie", (cookie).String()) - } - - w.Header().Add("_pto_"+cookie, chosenBack.Name) - - defer conn.Close() - - resHeader := make(http.Header) - if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil { - logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e)) - return e - } - - if req, e := Upgrade(w, r, resHeader); e != nil { - return errors.Join(ErrResDoFail, e) - } else { - defer req.Close() - - select { - case e := <-copyWsMsg(req, conn, blocksi): - if e != nil { - chosenBack.Disable() - logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, chosenBack.Name, e)) - return errors.Join(ErrCopy, e) - } - case e := <-copyWsMsg(conn, req, blocksi): - if e != nil { - chosenBack.Disable() - logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, chosenBack.Name, e)) - return errors.Join(ErrCopy, e) - } - case <-ctx.Done(): + for i := 0; i < len(*configS); i++ { + (*configS)[i].lock.Lock() + (*configS)[i].SwapSign(ctx, logger) + (*configS)[i].lock.Unlock() } - - return nil } + return md5k, nil } func copyHeader(s, t http.Header, app []Header) error { @@ -339,8 +133,18 @@ func copyHeader(s, t http.Header, app []Header) error { } for _, v := range app { switch v.Action { - case `check`: - if !v.Match(tm[v.Key][0]) { + case `deny`: + if va, ok := tm[v.Key]; ok && len(va) != 0 { + if exp, e := regexp.Compile(v.MatchExp); e == nil && exp.MatchString(va[0]) { + return ErrHeaderCheckFail + } + } + case `access`: + if va, ok := tm[v.Key]; ok && len(va) != 0 { + if exp, e := regexp.Compile(v.MatchExp); e != nil || !exp.MatchString(va[0]) { + return ErrHeaderCheckFail + } + } else { return ErrHeaderCheckFail } case `replace`: @@ -359,409 +163,16 @@ func copyHeader(s, t http.Header, app []Header) error { return nil } -func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-chan error { - c := make(chan error, 1) - go func() { - if tmpbuf, put, e := blocksi.Get(); e != nil { - c <- e - } else { - defer put() - _, e := io.CopyBuffer(dst, src, tmpbuf) - c <- e - } - }() - return c -} - -func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) { - d := websocket.DefaultDialer - - challengeKey := requestHeader.Get("Sec-WebSocket-Key") - - u, err := url.Parse(urlStr) - if err != nil { - return nil, nil, err - } - - switch u.Scheme { - case "ws": - u.Scheme = "http" - case "wss": - u.Scheme = "https" - default: - return nil, nil, errMalformedURL - } - - if u.User != nil { - // User name and password are not allowed in websocket URIs. - return nil, nil, errMalformedURL - } - - req := &http.Request{ - Method: http.MethodGet, - URL: u, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Host: u.Host, - } - req = req.WithContext(ctx) - - // Set the request headers using the capitalization for names and values in - // RFC examples. Although the capitalization shouldn't matter, there are - // servers that depend on it. The Header.Set method is not used because the - // method canonicalizes the header names. - for k, vs := range requestHeader { - req.Header[k] = vs - } - - if d.HandshakeTimeout != 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) - defer cancel() - } - - // Get network dial function. - var netDial func(network, add string) (net.Conn, error) - - switch u.Scheme { - case "http": - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) - } - } else if d.NetDial != nil { - netDial = d.NetDial - } - case "https": - if d.NetDialTLSContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialTLSContext(ctx, network, addr) - } - } else if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) - } - } else if d.NetDial != nil { - netDial = d.NetDial - } - default: - return nil, nil, errMalformedURL - } - - if netDial == nil { - netDialer := &net.Dialer{} - netDial = func(network, addr string) (net.Conn, error) { - return netDialer.DialContext(ctx, network, addr) - } - } - - // If needed, wrap the dial function to set the connection deadline. - if deadline, ok := ctx.Deadline(); ok { - forwardDial := netDial - netDial = func(network, addr string) (net.Conn, error) { - c, err := forwardDial(network, addr) - if err != nil { - return nil, err - } - err = c.SetDeadline(deadline) - if err != nil { - if err := c.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } - return c, nil - } - } - - // If needed, wrap the dial function to connect through a proxy. - if d.Proxy != nil { - proxyURL, err := d.Proxy(req) - if err != nil { - return nil, nil, err - } - if proxyURL != nil { - dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial)) - if err != nil { - return nil, nil, err - } - netDial = dialer.Dial - } - } - - hostPort, hostNoPort := hostPortNoPort(u) - trace := httptrace.ContextClientTrace(ctx) - if trace != nil && trace.GetConn != nil { - trace.GetConn(hostPort) - } - - netConn, err := netDial("tcp", hostPort) - if err != nil { - return nil, nil, err - } - if trace != nil && trace.GotConn != nil { - trace.GotConn(httptrace.GotConnInfo{ - Conn: netConn, - }) - } - - if u.Scheme == "https" && d.NetDialTLSContext == nil { - // If NetDialTLSContext is set, assume that the TLS handshake has already been done - - cfg := cloneTLSConfig(d.TLSClientConfig) - if cfg.ServerName == "" { - cfg.ServerName = hostNoPort - } - tlsConn := tls.Client(netConn, cfg) - netConn = tlsConn - - if trace != nil && trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := doHandshake(ctx, tlsConn, cfg) - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } - - if err != nil { - return nil, nil, err - } - } - - var br *bufio.Reader - if br == nil { - if d.ReadBufferSize == 0 { - d.ReadBufferSize = defaultReadBufferSize - } else if d.ReadBufferSize < maxControlFramePayloadSize { - // must be large enough for control frame - d.ReadBufferSize = maxControlFramePayloadSize - } - br = bufio.NewReaderSize(netConn, d.ReadBufferSize) - } - - if err := req.Write(netConn); err != nil { - return nil, nil, err - } - - if trace != nil && trace.GotFirstResponseByte != nil { - if peek, err := br.Peek(1); err == nil && len(peek) == 1 { - trace.GotFirstResponseByte() - } - } - - resp, err := http.ReadResponse(br, req) - if err != nil { - if d.TLSClientConfig != nil { - for _, proto := range d.TLSClientConfig.NextProtos { - if proto != "http/1.1" { - return nil, nil, fmt.Errorf( - "websocket: protocol %q was given but is not supported;"+ - "sharing tls.Config with net/http Transport can cause this error: %w", - proto, err, - ) - } - } - } - return nil, nil, err - } - - if d.Jar != nil { - if rc := resp.Cookies(); len(rc) > 0 { - d.Jar.SetCookies(u, rc) - } - } - - if resp.StatusCode != 101 || - !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || - !tokenListContainsValue(resp.Header, "Connection", "upgrade") || - resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { - // Before closing the network connection on return from this - // function, slurp up some of the response to aid application - // debugging. - buf := make([]byte, 1024) - n, _ := io.ReadFull(resp.Body, buf) - resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) - log.Default().Println(resp.StatusCode, resp.Header) - return nil, resp, websocket.ErrBadHandshake - } - - resp.Body = io.NopCloser(bytes.NewReader([]byte{})) - - if err := netConn.SetDeadline(time.Time{}); err != nil { - return nil, nil, err - } - return netConn, resp, nil -} - -type netDialerFunc func(network, addr string) (net.Conn, error) - -func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { - return fn(network, addr) -} - -//go:linkname doHandshake github.com/gorilla/websocket.doHandshake -func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error - -//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig -func cloneTLSConfig(cfg *tls.Config) *tls.Config - -//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort -func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) - -//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL -var errMalformedURL error - -//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression -var errInvalidCompression error - -//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey -func generateChallengeKey() (string, error) - -//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue -func tokenListContainsValue(header http.Header, name string, value string) bool - -//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError -// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error) - -//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin -func checkSameOrigin(r *http.Request) bool - -//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey -func isValidChallengeKey(s string) bool - -//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol -func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string - -//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions -func parseExtensions(header http.Header) []map[string]string - -//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize -func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int - -//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer -func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte - -//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey -func computeAcceptKey(challengeKey string) string - -const ( - maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask - defaultReadBufferSize = 4096 - defaultWriteBufferSize = 4096 - maxControlFramePayloadSize = 125 +var cookie = fmt.Sprintf("%p", &struct{}{}) +var header = "X-Front-" +var ( + ErrRedirect = errors.New("ErrRedirect") + ErrNoHttp = errors.New("ErrNoHttp") + ErrNoWs = errors.New("ErrNoWs") + ErrCopy = errors.New("ErrCopy") + ErrReqCreFail = errors.New("ErrReqCreFail") + ErrReqDoFail = errors.New("ErrReqDoFail") + ErrResDoFail = errors.New("ErrResDoFail") + ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail") + ErrBodyCheckFail = errors.New("ErrBodyCheckFail") ) - -func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) { - u := &websocket.Upgrader{} - h, ok := w.(http.Hijacker) - if !ok { - return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") - } - var brw *bufio.ReadWriter - netConn, brw, err := h.Hijack() - if err != nil { - return returnError(u, w, r, http.StatusInternalServerError, err.Error()) - } - - if brw.Reader.Buffered() > 0 { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, errors.New("websocket: client sent data before handshake is complete") - } - - buf := bufioWriterBuffer(netConn, brw.Writer) - - var writeBuf []byte - if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { - // Reuse hijacked write buffer as connection buffer. - writeBuf = buf - } else { - if u.WriteBufferSize <= 0 { - u.WriteBufferSize = defaultWriteBufferSize - } - u.WriteBufferSize += maxFrameHeaderSize - if writeBuf == nil && u.WriteBufferPool == nil { - writeBuf = make([]byte, u.WriteBufferSize) - } - } - - // Use larger of hijacked buffer and connection write buffer for header. - p := buf - if len(writeBuf) > len(p) { - p = writeBuf - } - p = p[:0] - - p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...) - for k, vs := range responseHeader { - for _, v := range vs { - p = append(p, k...) - p = append(p, ": "...) - for i := 0; i < len(v); i++ { - b := v[i] - if b <= 31 { - // prevent response splitting. - b = ' ' - } - p = append(p, b) - } - p = append(p, "\r\n"...) - } - } - p = append(p, "\r\n"...) - - // Clear deadlines set by HTTP server. - if err := netConn.SetDeadline(time.Time{}); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } - - if u.HandshakeTimeout > 0 { - if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } - } - if _, err = netConn.Write(p); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } - if u.HandshakeTimeout > 0 { - if err := netConn.SetWriteDeadline(time.Time{}); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } - } - - return netConn, nil -} - -func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) { - err := HandshakeError{message: reason} - if u.Error != nil { - u.Error(w, r, status, err) - } else { - w.Header().Set("Sec-Websocket-Version", "13") - http.Error(w, http.StatusText(status), status) - } - return nil, err -} - -type HandshakeError struct { - message string -} - -func (t HandshakeError) Error() string { - return t.message -} diff --git a/main/main.go b/main/main.go index 5508203..bcea759 100755 --- a/main/main.go +++ b/main/main.go @@ -74,7 +74,7 @@ func main() { go pfront.Test(ctx, *testP, logger.Base("测试")) for i := 0; i < len(configS); i++ { - go configS[i].Run(ctx, logger.Base(configS[i].Addr)) + go configS[i].Run(ctx, logger) } // ctrl+c退出 diff --git a/ws.go b/ws.go new file mode 100644 index 0000000..061269c --- /dev/null +++ b/ws.go @@ -0,0 +1,525 @@ +package front + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "time" + _ "unsafe" + + "github.com/gorilla/websocket" + pslice "github.com/qydysky/part/slice" + "golang.org/x/net/proxy" +) + +func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error { + var ( + opT = time.Now() + resp *http.Response + conn net.Conn + chosenBack *Back + ) + + for 0 < len(backs) && (resp == nil || conn == nil) { + chosenBack = backs[0] + backs = backs[1:] + + _, e := BodyMatchs(chosenBack.tmp.ReqBody, r) + if e != nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return errors.Join(ErrBodyCheckFail, e) + } + + url := chosenBack.To + if chosenBack.PathAdd { + url += r.URL.String() + } + + url = "ws" + url + + reqHeader := make(http.Header) + + if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return e + } + + conn, resp, e = DialContext(ctx, url, reqHeader) + if e != nil && !errors.Is(e, context.Canceled) { + chosenBack.Disable() + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + } + } + + if 0 == len(backs) && (resp == nil || conn == nil) { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws 全部后端故障 %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT))) + return errors.New("全部后端故障") + } else if resp == nil || conn == nil { + return errors.New("后端故障") + } + + logger.Debug(`T:`, fmt.Sprintf("%v > %v > %v ws ok %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT))) + + { + cookie := &http.Cookie{ + Name: "_psign_" + cookie, + Value: chosenBack.Id(), + MaxAge: chosenBack.Splicing, + Path: "/", + } + if validCookieDomain(r.Host) { + cookie.Domain = r.Host + } + w.Header().Add("Set-Cookie", (cookie).String()) + } + + w.Header().Add(header+"Info", cookie+";"+chosenBack.Name) + + defer conn.Close() + + resHeader := make(http.Header) + if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil { + logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return e + } + + if req, e := Upgrade(w, r, resHeader); e != nil { + return errors.Join(ErrResDoFail, e) + } else { + defer req.Close() + + select { + case e := <-copyWsMsg(req, conn, blocksi): + if e != nil { + chosenBack.Disable() + logger.Error(`E:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return errors.Join(ErrCopy, e) + } + case e := <-copyWsMsg(conn, req, blocksi): + if e != nil { + chosenBack.Disable() + logger.Error(`E:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT))) + return errors.Join(ErrCopy, e) + } + case <-ctx.Done(): + } + + return nil + } +} + +func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-chan error { + c := make(chan error, 1) + go func() { + if tmpbuf, put, e := blocksi.Get(); e != nil { + c <- e + } else { + defer put() + _, e := io.CopyBuffer(dst, src, tmpbuf) + c <- e + } + }() + return c +} + +func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) { + d := websocket.DefaultDialer + + challengeKey := requestHeader.Get("Sec-WebSocket-Key") + + u, err := url.Parse(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req = req.WithContext(ctx) + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + for k, vs := range requestHeader { + req.Header[k] = vs + } + + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } + + // Get network dial function. + var netDial func(network, add string) (net.Conn, error) + + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { + netDialer := &net.Dialer{} + netDial = func(network, addr string) (net.Conn, error) { + return netDialer.DialContext(ctx, network, addr) + } + } + + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + forwardDial := netDial + netDial = func(network, addr string) (net.Conn, error) { + c, err := forwardDial(network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + if err := c.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + return c, nil + } + } + + // If needed, wrap the dial function to connect through a proxy. + if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial)) + if err != nil { + return nil, nil, err + } + netDial = dialer.Dial + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + + netConn, err := netDial("tcp", hostPort) + if err != nil { + return nil, nil, err + } + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + + if u.Scheme == "https" && d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done + + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + + if err != nil { + return nil, nil, err + } + } + + var br *bufio.Reader + if br == nil { + if d.ReadBufferSize == 0 { + d.ReadBufferSize = defaultReadBufferSize + } else if d.ReadBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame + d.ReadBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(netConn, d.ReadBufferSize) + } + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + + resp, err := http.ReadResponse(br, req) + if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } + return nil, nil, err + } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + + if resp.StatusCode != 101 || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + log.Default().Println(resp.StatusCode, resp.Header) + return nil, resp, websocket.ErrBadHandshake + } + + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) + + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, nil, err + } + return netConn, resp, nil +} + +type netDialerFunc func(network, addr string) (net.Conn, error) + +func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return fn(network, addr) +} + +//go:linkname doHandshake github.com/gorilla/websocket.doHandshake +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error + +//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig +func cloneTLSConfig(cfg *tls.Config) *tls.Config + +//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) + +//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL +var errMalformedURL error + +//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression +var errInvalidCompression error + +//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey +func generateChallengeKey() (string, error) + +//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue +func tokenListContainsValue(header http.Header, name string, value string) bool + +//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError +// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error) + +//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin +func checkSameOrigin(r *http.Request) bool + +//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey +func isValidChallengeKey(s string) bool + +//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol +func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string + +//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions +func parseExtensions(header http.Header) []map[string]string + +//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize +func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int + +//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer +func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte + +//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey +func computeAcceptKey(challengeKey string) string + +const ( + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + maxControlFramePayloadSize = 125 +) + +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) { + u := &websocket.Upgrader{} + h, ok := w.(http.Hijacker) + if !ok { + return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") + } + var brw *bufio.ReadWriter + netConn, brw, err := h.Hijack() + if err != nil { + return returnError(u, w, r, http.StatusInternalServerError, err.Error()) + } + + if brw.Reader.Buffered() > 0 { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, errors.New("websocket: client sent data before handshake is complete") + } + + buf := bufioWriterBuffer(netConn, brw.Writer) + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } else { + if u.WriteBufferSize <= 0 { + u.WriteBufferSize = defaultWriteBufferSize + } + u.WriteBufferSize += maxFrameHeaderSize + if writeBuf == nil && u.WriteBufferPool == nil { + writeBuf = make([]byte, u.WriteBufferSize) + } + } + + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(writeBuf) > len(p) { + p = writeBuf + } + p = p[:0] + + p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...) + for k, vs := range responseHeader { + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + } + if _, err = netConn.Write(p); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } + } + + return netConn, nil +} + +func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) { + err := HandshakeError{message: reason} + if u.Error != nil { + u.Error(w, r, status, err) + } else { + w.Header().Set("Sec-Websocket-Version", "13") + http.Error(w, http.StatusText(status), status) + } + return nil, err +} + +type HandshakeError struct { + message string +} + +func (t HandshakeError) Error() string { + return t.message +} -- 2.39.2