matcher:
-- matchHeader: [] 匹配客户端请求头,只有都匹配才使用此后端, 可以动态增加/删除
- - key: string 要匹配的header名
- - matchExp: string 要匹配的正则式
- - value: string 要匹配的值
- reqHeader: [] 请求后端前,请求头处理器, 可以动态增加/删除
- - action: string 可选check、replace、add、del、set。
+ - action: string 可选access、deny、replace、add、del、set。
- key: string 具体处理哪个头
- - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
- - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+ - matchExp: string access时不匹配将结束请求。deny时匹配将结束请求。replace时结合value进行替换
+ - value: string replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
- resHeader: [] 返回后端的响应前,请求头处理器, 可以动态增加/删除
- - action: string 可选check、add、del、set。
+ - action: string 可选access、deny、add、del、set。
- key: string 具体处理哪个头
- - matchExp: string check时,如不匹配将结束请求。replace时结合value进行替换
- - value: string check时,如不匹配将结束请求。replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
+ - matchExp: string access时不匹配将结束请求。deny时匹配将结束请求。replace时结合value进行替换
+ - value: string replace时结合matchExp进行替换。add时将附加值。set时将覆盖值。
- reqBody: [] 请求后端前,请求数据过滤器, 可以动态增加/删除
- - action: string 可选access,deny。
+ - action: string 可选access、deny。
- reqSize:string 限定请求数据大小,默认为"1M"
- matchExp: string access时如不匹配将结束请求。deny时如匹配将结束请求。
"fmt"
"io"
"math/rand/v2"
+ "net"
"net/http"
"regexp"
"sync"
CopyBlocks int `json:"copyBlocks"`
BlocksI pslice.BlocksI[byte] `json:"-"`
+ routeP pweb.WebPath
routeMap sync.Map `json:"-"`
Routes []Route `json:"routes"`
}
ctx, done := pctx.WithWait(ctx, 0, time.Minute)
defer done()
- routeP := pweb.WebPath{}
-
var matchfunc func(path string) (func(w http.ResponseWriter, r *http.Request), bool)
switch t.MatchRule {
case "all":
- matchfunc = routeP.Load
+ matchfunc = t.routeP.Load
default:
- matchfunc = routeP.LoadPerfix
+ matchfunc = t.routeP.LoadPerfix
}
- httpSer := http.Server{Addr: t.Addr}
+ httpSer := http.Server{
+ Addr: t.Addr,
+ BaseContext: func(l net.Listener) context.Context { return ctx },
+ }
if t.TLS.Key != "" && t.TLS.Pub != "" {
if cert, e := tls.LoadX509KeyPair(t.TLS.Pub, t.TLS.Key); e != nil {
- logger.Error(`E:`, e)
+ logger.Error(`E:`, fmt.Sprintf("%v %v", t.Addr, e))
} else {
httpSer.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
t.BlocksI = pslice.NewBlocks[byte](16*1024, t.CopyBlocks)
}
- syncWeb := pweb.NewSyncMap(&httpSer, &routeP, matchfunc)
+ syncWeb := pweb.NewSyncMap(&httpSer, &t.routeP, matchfunc)
defer syncWeb.Shutdown()
- var addRoute = func(k string, route *Route) {
- logger.Info(`I:`, "路由加载", k)
- t.routeMap.Store(k, route)
+ t.SwapSign(ctx, logger)
+ logger.Info(`I:`, fmt.Sprintf("%v running", t.Addr))
+ <-ctx.Done()
+ logger.Info(`I:`, fmt.Sprintf("%v shutdown", t.Addr))
+}
- routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) {
- ctx1, done1 := pctx.WaitCtx(ctx)
- defer done1()
+func (t *Config) SwapSign(ctx context.Context, logger Logger) {
+ var add = func(k string, route *Route, logger Logger) {
+ route.config = t
+ logger.Info(`I:`, fmt.Sprintf("%v > %v", t.Addr, k))
+ t.routeMap.Store(k, route)
- if !HeaderMatchs(route.MatchHeader, r) {
+ t.routeP.Store(route.Path, func(w http.ResponseWriter, r *http.Request) {
+ if !HeaderMatchs(route.ReqHeader, r) {
w.WriteHeader(http.StatusNotFound)
}
var backIs []*Back
if t, e := r.Cookie("_psign_" + cookie); e == nil {
- if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).MatchHeader, r) {
+ if backP, ok := route.backMap.Load(t.Value); ok && backP.(*Back).IsLive() && HeaderMatchs(backP.(*Back).ReqHeader, r) {
backP.(*Back).PathAdd = route.PathAdd
backP.(*Back).Splicing = route.Splicing
backP.(*Back).tmp.ReqHeader = append(route.ReqHeader, backP.(*Back).ReqHeader...)
var e error
if r.Header.Get("Upgrade") == "websocket" {
- e = wsDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI)
+ e = wsDealer(r.Context(), w, r, route.Path, backIs, logger, t.BlocksI)
} else {
- e = httpDealer(ctx1, w, r, route.Path, backIs, logger, t.BlocksI)
+ e = httpDealer(r.Context(), w, r, route.Path, backIs, logger, t.BlocksI)
+ }
+ if e != nil {
+ w.Header().Add(header+"Error", e.Error())
}
if errors.Is(e, ErrHeaderCheckFail) || errors.Is(e, ErrBodyCheckFail) {
w.WriteHeader(http.StatusForbidden)
})
}
- var delRoute = func(k string, route *Route) {
- logger.Info(`I:`, "路由移除", k)
+ var del = func(k string, route *Route, logger Logger) {
+ logger.Info(`I:`, fmt.Sprintf("%v x %v", t.Addr, k))
t.routeMap.Delete(k)
- routeP.Store(k, nil)
+ t.routeP.Store(k, nil)
}
- t.SwapSign(addRoute, delRoute, logger)
- logger.Info(`I:`, "启动完成")
- for {
- select {
- case <-ctx.Done():
- return
- case <-time.After(time.Second * 10):
- t.SwapSign(addRoute, delRoute, logger)
- }
+ var routeU = func(route *Route, logger Logger) {
+ route.SwapSign(
+ func(k string, b *Back) {
+ b.route = route
+ logger.Info(`I:`, fmt.Sprintf("%v > %v > %v", t.Addr, route.Path, b.Name))
+ route.backMap.Store(k, b)
+ },
+ func(k string, b *Back) {
+ logger.Info(`I:`, fmt.Sprintf("%v > %v x %v", t.Addr, route.Path, b.Name))
+ route.backMap.Delete(k)
+ },
+ logger,
+ )
}
-}
-func (t *Config) SwapSign(add func(string, *Route), del func(string, *Route), logger Logger) {
t.routeMap.Range(func(key, value any) bool {
var exist bool
for k := 0; k < len(t.Routes); k++ {
}
}
if !exist {
- del(key.(string), value.(*Route))
+ del(key.(string), value.(*Route), logger)
}
return true
})
for i := 0; i < len(t.Routes); i++ {
if _, ok := t.routeMap.Load(t.Routes[i].Path); !ok {
- add(t.Routes[i].Path, &t.Routes[i])
+ routeU(&t.Routes[i], logger)
+ add(t.Routes[i].Path, &t.Routes[i], logger)
}
}
-
- for i := 0; i < len(t.Routes); i++ {
- t.Routes[i].SwapSign(
- func(k string, b *Back) {
- logger.Info(`I:`, "后端加载", t.Routes[i].Path, b.Name)
- t.Routes[i].backMap.Store(k, b)
- },
- func(k string, b *Back) {
- logger.Info(`I:`, "后端移除", t.Routes[i].Path, b.Name)
- t.Routes[i].backMap.Delete(k)
- },
- logger,
- )
- }
}
type Route struct {
- Path string `json:"path"`
+ config *Config `json:"-"`
+ Path string `json:"path"`
Splicing int `json:"splicing"`
PathAdd bool `json:"pathAdd"`
func (t *Route) FiliterBackByRequest(r *http.Request) []*Back {
var backLink []*Back
for i := 0; i < len(t.Backs); i++ {
- if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].MatchHeader, r) {
+ if t.Backs[i].IsLive() && HeaderMatchs(t.Backs[i].ReqHeader, r) {
t.Backs[i].PathAdd = t.PathAdd
t.Backs[i].Splicing = t.Splicing
t.Backs[i].tmp.ReqHeader = append(t.ReqHeader, t.Backs[i].ReqHeader...)
}
type Back struct {
- lock sync.RWMutex `json:"-"`
- upT time.Time `json:"-"`
+ route *Route `json:"-"`
+ lock sync.RWMutex `json:"-"`
+ upT time.Time `json:"-"`
Name string `json:"name"`
To string `json:"to"`
}
type Matcher struct {
- MatchHeader []Header `json:"matchHeader"`
- ReqHeader []Header `json:"reqHeader"`
- ResHeader []Header `json:"resHeader"`
- ReqBody []Body `json:"reqBody"`
+ ReqHeader []Header `json:"reqHeader"`
+ ResHeader []Header `json:"resHeader"`
+ ReqBody []Body `json:"reqBody"`
}
type Header struct {
}
func (t *Header) Match(value string) bool {
- if t.Value != "" && value != t.Value {
- return false
+ if t.Action != "access" && t.Action != "deny" {
+ return true
}
if t.MatchExp != "" {
if exp, e := regexp.Compile(t.MatchExp); e != nil || !exp.MatchString(value) {
- return false
+ return t.Action == "deny"
}
}
- return true
+ return t.Action == "access"
}
type Body struct {
module github.com/qydysky/front
-go 1.21
+go 1.22
require (
+ github.com/dustin/go-humanize v1.0.1
github.com/gorilla/websocket v1.5.1
- github.com/qydysky/part v0.28.20240114140844
+ github.com/qydysky/part v0.28.20240309172046
golang.org/x/net v0.18.0
)
require (
+ github.com/andybalholm/brotli v1.0.6 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
- github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/google/uuid v1.4.0 // indirect
+ github.com/klauspost/compress v1.17.3 // indirect
+ github.com/miekg/dns v1.1.57 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/shirou/gopsutil v3.21.11+incompatible // indirect
+ github.com/thedevsaddam/gojsonq/v2 v2.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
+ golang.org/x/mod v0.14.0 // indirect
golang.org/x/sys v0.14.0 // indirect
golang.org/x/text v0.14.0 // indirect
+ golang.org/x/tools v0.15.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
+github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM=
+github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/qydysky/part v0.28.20240114140844 h1:9d0NPLx2CHKFca5Q6brECXeCjj0VXXaXmZRcJo6W85w=
-github.com/qydysky/part v0.28.20240114140844/go.mod h1:NyKyjpBCSjcHtKlC+fL5lCidm57UCnwEgufiBDs5yxA=
+github.com/qydysky/part v0.28.20240309114649 h1:b82WHpgNecfv/UX4makU9EIHJ4nLAP9dnXCb8S+sJt4=
+github.com/qydysky/part v0.28.20240309114649/go.mod h1:8Y4MrasGC0BLEM71QY/MuP2jl+v5b0Y+rqox3qJu97c=
+github.com/qydysky/part v0.28.20240309172046 h1:aw2Dv8VaP0p+IMkwJQlCNaz0ccJ6l8YUhu+y39kvQgU=
+github.com/qydysky/part v0.28.20240309172046/go.mod h1:8Y4MrasGC0BLEM71QY/MuP2jl+v5b0Y+rqox3qJu97c=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/thedevsaddam/gojsonq/v2 v2.5.2 h1:CoMVaYyKFsVj6TjU6APqAhAvC07hTI6IQen8PHzHYY0=
+github.com/thedevsaddam/gojsonq/v2 v2.5.2/go.mod h1:bv6Xa7kWy82uT0LnXPE2SzGqTj33TAEeR560MdJkiXs=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
--- /dev/null
+package front
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+ _ "unsafe"
+
+ pslice "github.com/qydysky/part/slice"
+)
+
+func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
+
+ var (
+ opT = time.Now()
+ resp *http.Response
+ chosenBack *Back
+ )
+
+ for 0 < len(backs) && resp == nil {
+ chosenBack = backs[0]
+ backs = backs[1:]
+
+ url := chosenBack.To
+ if chosenBack.PathAdd {
+ url += r.URL.String()
+ }
+
+ url = "http" + url
+
+ reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
+ if e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return errors.Join(ErrBodyCheckFail, e)
+ }
+
+ req, e := http.NewRequestWithContext(ctx, r.Method, url, reader)
+ if e != nil {
+ return errors.Join(ErrReqCreFail, e)
+ }
+
+ if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return e
+ }
+
+ client := http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return ErrRedirect
+ },
+ }
+ resp, e = client.Do(req)
+ if e != nil && !errors.Is(e, ErrRedirect) && !errors.Is(e, context.Canceled) {
+ chosenBack.Disable()
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ }
+ }
+
+ if 0 == len(backs) && resp == nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http 全部后端故障 %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+ return errors.New("全部后端故障")
+ } else if resp == nil {
+ return errors.New("后端故障")
+ }
+
+ logger.Debug(`T:`, fmt.Sprintf("%v > %v > %v http ok %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+
+ {
+ cookie := &http.Cookie{
+ Name: "_psign_" + cookie,
+ Value: chosenBack.Id(),
+ MaxAge: chosenBack.Splicing,
+ Path: "/",
+ }
+ if validCookieDomain(r.Host) {
+ cookie.Domain = r.Host
+ }
+ w.Header().Add("Set-Cookie", (cookie).String())
+ }
+
+ w.Header().Add(header+"Info", cookie+";"+chosenBack.Name)
+
+ if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return e
+ }
+
+ w.WriteHeader(resp.StatusCode)
+
+ if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 {
+ return nil
+ }
+
+ defer resp.Body.Close()
+ if tmpbuf, put, e := blocksi.Get(); e != nil {
+ logger.Error(`E:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ chosenBack.Disable()
+ return errors.Join(ErrCopy, e)
+ } else {
+ defer put()
+ if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil {
+ logger.Error(`E:`, fmt.Sprintf("%v > %v > %v http %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ chosenBack.Disable()
+ return errors.Join(ErrCopy, e)
+ }
+ }
+ return nil
+}
package front
import (
- "bufio"
- "bytes"
"context"
- "crypto/tls"
+ "crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
- "log"
- "net"
"net/http"
- "net/http/httptrace"
- "net/url"
"regexp"
"strings"
"time"
_ "unsafe"
- "github.com/gorilla/websocket"
pctx "github.com/qydysky/part/ctx"
- pslice "github.com/qydysky/part/slice"
pweb "github.com/qydysky/part/web"
- "golang.org/x/net/proxy"
)
type Logger interface {
// 加载
func LoadPeriod(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) error {
- if e := loadConfig(buf, configF, configS); e != nil {
+ var oldBufMd5 string
+
+ if bufMd5, e := loadConfig(ctx, buf, configF, configS, logger); e != nil {
logger.Error(`E:`, "配置加载", e)
return e
+ } else {
+ oldBufMd5 = bufMd5
}
+
+ logger.Info(`I:`, "配置更新", oldBufMd5[:5])
+
// 定时加载config
go func() {
ctx1, done1 := pctx.WaitCtx(ctx)
for {
select {
case <-time.After(time.Second * 5):
- if e := loadConfig(buf, configF, configS); e != nil {
+ if bufMd5, e := loadConfig(ctx, buf, configF, configS, logger); e != nil {
logger.Error(`E:`, "配置加载", e)
+ } else if bufMd5 != oldBufMd5 {
+ oldBufMd5 = bufMd5
+ logger.Info(`I:`, "配置更新", oldBufMd5[:5])
}
case <-ctx1.Done():
return
<-ctx1.Done()
}
-var cookie = fmt.Sprintf("%p", &struct{}{})
-
-func loadConfig(buf []byte, configF File, configS *[]Config) error {
+func loadConfig(ctx context.Context, buf []byte, configF File, configS *[]Config, logger Logger) (md5k string, e error) {
+ defer func() {
+ if err := recover(); err != nil {
+ logger.Error(`E:`, err)
+ e = errors.New("read panic")
+ }
+ }()
if i, e := configF.Read(buf); e != nil && !errors.Is(e, io.EOF) {
- return e
+ return "", e
} else if i == cap(buf) {
- return errors.New(`buf full`)
+ return "", errors.New(`buf full`)
} else {
- for i := 0; i < len(*configS); i++ {
- (*configS)[i].lock.Lock()
- defer (*configS)[i].lock.Unlock()
- }
- if e := json.Unmarshal(buf[:i], configS); e != nil {
- return e
- }
- }
- return nil
-}
+ w := md5.New()
+ w.Write(buf[:i])
+ md5k = fmt.Sprintf("%x", w.Sum(nil))
-//go:linkname nanotime1 runtime.nanotime1
-func nanotime1() int64
-
-var (
- ErrRedirect = errors.New("ErrRedirect")
- ErrNoHttp = errors.New("ErrNoHttp")
- ErrNoWs = errors.New("ErrNoWs")
- ErrCopy = errors.New("ErrCopy")
- ErrReqCreFail = errors.New("ErrReqCreFail")
- ErrReqDoFail = errors.New("ErrReqDoFail")
- ErrResDoFail = errors.New("ErrResDoFail")
- ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail")
- ErrBodyCheckFail = errors.New("ErrBodyCheckFail")
-)
-
-func httpDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
- var (
- opT = time.Now()
- resp *http.Response
- chosenBack *Back
- )
-
- for 0 < len(backs) && resp == nil {
- chosenBack = backs[0]
- backs = backs[1:]
-
- url := chosenBack.To
- if chosenBack.PathAdd {
- url += r.URL.String()
- }
-
- url = "http" + url
-
- for _, v := range chosenBack.tmp.ReqHeader {
- if v.Action == `check` {
- if r.Header.Get(v.Key) != v.Value {
- return ErrHeaderCheckFail
- }
- }
- }
-
- reader, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
- if e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail))
- return errors.Join(ErrBodyCheckFail, e)
- }
-
- req, e := http.NewRequestWithContext(ctx, r.Method, url, reader)
- if e != nil {
- return errors.Join(ErrReqCreFail, e)
- }
-
- if e := copyHeader(r.Header, req.Header, chosenBack.tmp.ReqHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- return e
- }
-
- client := http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return ErrRedirect
- },
- }
- resp, e = client.Do(req)
- if e != nil && !errors.Is(e, ErrRedirect) {
- chosenBack.Disable()
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- }
- }
-
- if 0 == len(backs) && resp == nil {
- logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name))
- return errors.New("全部后端故障")
- } else if resp == nil {
- return errors.New("后端故障")
- }
-
- logger.Debug(`T:`, fmt.Sprintf("http %s=>%s %v", routePath, chosenBack.Name, time.Since(opT)))
-
- {
- cookie := &http.Cookie{
- Name: "_psign_" + cookie,
- Value: chosenBack.Id(),
- MaxAge: chosenBack.Splicing,
- Path: "/",
- }
- if validCookieDomain(r.Host) {
- cookie.Domain = r.Host
- }
- w.Header().Add("Set-Cookie", (cookie).String())
- }
-
- w.Header().Add("_pto_"+cookie, chosenBack.Name)
-
- if e := copyHeader(resp.Header, w.Header(), chosenBack.tmp.ResHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- return e
- }
-
- w.WriteHeader(resp.StatusCode)
-
- if resp.StatusCode < 200 || resp.StatusCode == 204 || resp.StatusCode == 304 {
- return nil
- }
-
- defer resp.Body.Close()
- if tmpbuf, put, e := blocksi.Get(); e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- chosenBack.Disable()
- return errors.Join(ErrCopy, e)
- } else {
- defer put()
- if _, e = io.CopyBuffer(w, resp.Body, tmpbuf); e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- chosenBack.Disable()
- return errors.Join(ErrCopy, e)
- }
- }
- return nil
-}
-
-func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
- var (
- opT = time.Now()
- resp *http.Response
- conn net.Conn
- chosenBack *Back
- )
-
- for 0 < len(backs) && (resp == nil || conn == nil) {
- chosenBack = backs[0]
- backs = backs[1:]
-
- _, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
- if e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, ErrBodyCheckFail))
- return errors.Join(ErrBodyCheckFail, e)
- }
-
- url := chosenBack.To
- if chosenBack.PathAdd {
- url += r.URL.String()
- }
-
- url = "ws" + url
-
- reqHeader := make(http.Header)
-
- if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- return e
- }
-
- conn, resp, e = DialContext(ctx, url, reqHeader)
- if e != nil {
- chosenBack.Disable()
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- }
- }
-
- if 0 == len(backs) && (resp == nil || conn == nil) {
- logger.Warn(`E:`, fmt.Sprintf("%s=>%s 全部后端故障", routePath, chosenBack.Name))
- return errors.New("全部后端故障")
- } else if resp == nil || conn == nil {
- return errors.New("后端故障")
- }
-
- logger.Debug(`T:`, fmt.Sprintf("ws %s=>%s %v", routePath, chosenBack.Name, time.Since(opT)))
-
- {
- cookie := &http.Cookie{
- Name: "_psign_" + cookie,
- Value: chosenBack.Id(),
- MaxAge: chosenBack.Splicing,
- Path: "/",
- }
- if validCookieDomain(r.Host) {
- cookie.Domain = r.Host
+ if e := json.Unmarshal(buf[:i], configS); e != nil {
+ return md5k, e
}
- w.Header().Add("Set-Cookie", (cookie).String())
- }
-
- w.Header().Add("_pto_"+cookie, chosenBack.Name)
-
- defer conn.Close()
-
- resHeader := make(http.Header)
- if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil {
- logger.Warn(`W:`, fmt.Sprintf("%s=>%s %v", routePath, chosenBack.Name, e))
- return e
- }
-
- if req, e := Upgrade(w, r, resHeader); e != nil {
- return errors.Join(ErrResDoFail, e)
- } else {
- defer req.Close()
-
- select {
- case e := <-copyWsMsg(req, conn, blocksi):
- if e != nil {
- chosenBack.Disable()
- logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, chosenBack.Name, e))
- return errors.Join(ErrCopy, e)
- }
- case e := <-copyWsMsg(conn, req, blocksi):
- if e != nil {
- chosenBack.Disable()
- logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, chosenBack.Name, e))
- return errors.Join(ErrCopy, e)
- }
- case <-ctx.Done():
+ for i := 0; i < len(*configS); i++ {
+ (*configS)[i].lock.Lock()
+ (*configS)[i].SwapSign(ctx, logger)
+ (*configS)[i].lock.Unlock()
}
-
- return nil
}
+ return md5k, nil
}
func copyHeader(s, t http.Header, app []Header) error {
}
for _, v := range app {
switch v.Action {
- case `check`:
- if !v.Match(tm[v.Key][0]) {
+ case `deny`:
+ if va, ok := tm[v.Key]; ok && len(va) != 0 {
+ if exp, e := regexp.Compile(v.MatchExp); e == nil && exp.MatchString(va[0]) {
+ return ErrHeaderCheckFail
+ }
+ }
+ case `access`:
+ if va, ok := tm[v.Key]; ok && len(va) != 0 {
+ if exp, e := regexp.Compile(v.MatchExp); e != nil || !exp.MatchString(va[0]) {
+ return ErrHeaderCheckFail
+ }
+ } else {
return ErrHeaderCheckFail
}
case `replace`:
return nil
}
-func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-chan error {
- c := make(chan error, 1)
- go func() {
- if tmpbuf, put, e := blocksi.Get(); e != nil {
- c <- e
- } else {
- defer put()
- _, e := io.CopyBuffer(dst, src, tmpbuf)
- c <- e
- }
- }()
- return c
-}
-
-func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) {
- d := websocket.DefaultDialer
-
- challengeKey := requestHeader.Get("Sec-WebSocket-Key")
-
- u, err := url.Parse(urlStr)
- if err != nil {
- return nil, nil, err
- }
-
- switch u.Scheme {
- case "ws":
- u.Scheme = "http"
- case "wss":
- u.Scheme = "https"
- default:
- return nil, nil, errMalformedURL
- }
-
- if u.User != nil {
- // User name and password are not allowed in websocket URIs.
- return nil, nil, errMalformedURL
- }
-
- req := &http.Request{
- Method: http.MethodGet,
- URL: u,
- Proto: "HTTP/1.1",
- ProtoMajor: 1,
- ProtoMinor: 1,
- Header: make(http.Header),
- Host: u.Host,
- }
- req = req.WithContext(ctx)
-
- // Set the request headers using the capitalization for names and values in
- // RFC examples. Although the capitalization shouldn't matter, there are
- // servers that depend on it. The Header.Set method is not used because the
- // method canonicalizes the header names.
- for k, vs := range requestHeader {
- req.Header[k] = vs
- }
-
- if d.HandshakeTimeout != 0 {
- var cancel func()
- ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
- defer cancel()
- }
-
- // Get network dial function.
- var netDial func(network, add string) (net.Conn, error)
-
- switch u.Scheme {
- case "http":
- if d.NetDialContext != nil {
- netDial = func(network, addr string) (net.Conn, error) {
- return d.NetDialContext(ctx, network, addr)
- }
- } else if d.NetDial != nil {
- netDial = d.NetDial
- }
- case "https":
- if d.NetDialTLSContext != nil {
- netDial = func(network, addr string) (net.Conn, error) {
- return d.NetDialTLSContext(ctx, network, addr)
- }
- } else if d.NetDialContext != nil {
- netDial = func(network, addr string) (net.Conn, error) {
- return d.NetDialContext(ctx, network, addr)
- }
- } else if d.NetDial != nil {
- netDial = d.NetDial
- }
- default:
- return nil, nil, errMalformedURL
- }
-
- if netDial == nil {
- netDialer := &net.Dialer{}
- netDial = func(network, addr string) (net.Conn, error) {
- return netDialer.DialContext(ctx, network, addr)
- }
- }
-
- // If needed, wrap the dial function to set the connection deadline.
- if deadline, ok := ctx.Deadline(); ok {
- forwardDial := netDial
- netDial = func(network, addr string) (net.Conn, error) {
- c, err := forwardDial(network, addr)
- if err != nil {
- return nil, err
- }
- err = c.SetDeadline(deadline)
- if err != nil {
- if err := c.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- return nil, err
- }
- return c, nil
- }
- }
-
- // If needed, wrap the dial function to connect through a proxy.
- if d.Proxy != nil {
- proxyURL, err := d.Proxy(req)
- if err != nil {
- return nil, nil, err
- }
- if proxyURL != nil {
- dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
- if err != nil {
- return nil, nil, err
- }
- netDial = dialer.Dial
- }
- }
-
- hostPort, hostNoPort := hostPortNoPort(u)
- trace := httptrace.ContextClientTrace(ctx)
- if trace != nil && trace.GetConn != nil {
- trace.GetConn(hostPort)
- }
-
- netConn, err := netDial("tcp", hostPort)
- if err != nil {
- return nil, nil, err
- }
- if trace != nil && trace.GotConn != nil {
- trace.GotConn(httptrace.GotConnInfo{
- Conn: netConn,
- })
- }
-
- if u.Scheme == "https" && d.NetDialTLSContext == nil {
- // If NetDialTLSContext is set, assume that the TLS handshake has already been done
-
- cfg := cloneTLSConfig(d.TLSClientConfig)
- if cfg.ServerName == "" {
- cfg.ServerName = hostNoPort
- }
- tlsConn := tls.Client(netConn, cfg)
- netConn = tlsConn
-
- if trace != nil && trace.TLSHandshakeStart != nil {
- trace.TLSHandshakeStart()
- }
- err := doHandshake(ctx, tlsConn, cfg)
- if trace != nil && trace.TLSHandshakeDone != nil {
- trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
- }
-
- if err != nil {
- return nil, nil, err
- }
- }
-
- var br *bufio.Reader
- if br == nil {
- if d.ReadBufferSize == 0 {
- d.ReadBufferSize = defaultReadBufferSize
- } else if d.ReadBufferSize < maxControlFramePayloadSize {
- // must be large enough for control frame
- d.ReadBufferSize = maxControlFramePayloadSize
- }
- br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
- }
-
- if err := req.Write(netConn); err != nil {
- return nil, nil, err
- }
-
- if trace != nil && trace.GotFirstResponseByte != nil {
- if peek, err := br.Peek(1); err == nil && len(peek) == 1 {
- trace.GotFirstResponseByte()
- }
- }
-
- resp, err := http.ReadResponse(br, req)
- if err != nil {
- if d.TLSClientConfig != nil {
- for _, proto := range d.TLSClientConfig.NextProtos {
- if proto != "http/1.1" {
- return nil, nil, fmt.Errorf(
- "websocket: protocol %q was given but is not supported;"+
- "sharing tls.Config with net/http Transport can cause this error: %w",
- proto, err,
- )
- }
- }
- }
- return nil, nil, err
- }
-
- if d.Jar != nil {
- if rc := resp.Cookies(); len(rc) > 0 {
- d.Jar.SetCookies(u, rc)
- }
- }
-
- if resp.StatusCode != 101 ||
- !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
- !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
- resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
- // Before closing the network connection on return from this
- // function, slurp up some of the response to aid application
- // debugging.
- buf := make([]byte, 1024)
- n, _ := io.ReadFull(resp.Body, buf)
- resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
- log.Default().Println(resp.StatusCode, resp.Header)
- return nil, resp, websocket.ErrBadHandshake
- }
-
- resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
-
- if err := netConn.SetDeadline(time.Time{}); err != nil {
- return nil, nil, err
- }
- return netConn, resp, nil
-}
-
-type netDialerFunc func(network, addr string) (net.Conn, error)
-
-func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
- return fn(network, addr)
-}
-
-//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
-func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
-
-//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
-func cloneTLSConfig(cfg *tls.Config) *tls.Config
-
-//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
-func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
-
-//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
-var errMalformedURL error
-
-//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
-var errInvalidCompression error
-
-//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
-func generateChallengeKey() (string, error)
-
-//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
-func tokenListContainsValue(header http.Header, name string, value string) bool
-
-//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
-// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
-
-//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
-func checkSameOrigin(r *http.Request) bool
-
-//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
-func isValidChallengeKey(s string) bool
-
-//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
-func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
-
-//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
-func parseExtensions(header http.Header) []map[string]string
-
-//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
-func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
-
-//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
-func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
-
-//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
-func computeAcceptKey(challengeKey string) string
-
-const (
- maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
- defaultReadBufferSize = 4096
- defaultWriteBufferSize = 4096
- maxControlFramePayloadSize = 125
+var cookie = fmt.Sprintf("%p", &struct{}{})
+var header = "X-Front-"
+var (
+ ErrRedirect = errors.New("ErrRedirect")
+ ErrNoHttp = errors.New("ErrNoHttp")
+ ErrNoWs = errors.New("ErrNoWs")
+ ErrCopy = errors.New("ErrCopy")
+ ErrReqCreFail = errors.New("ErrReqCreFail")
+ ErrReqDoFail = errors.New("ErrReqDoFail")
+ ErrResDoFail = errors.New("ErrResDoFail")
+ ErrHeaderCheckFail = errors.New("ErrHeaderCheckFail")
+ ErrBodyCheckFail = errors.New("ErrBodyCheckFail")
)
-
-func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) {
- u := &websocket.Upgrader{}
- h, ok := w.(http.Hijacker)
- if !ok {
- return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
- }
- var brw *bufio.ReadWriter
- netConn, brw, err := h.Hijack()
- if err != nil {
- return returnError(u, w, r, http.StatusInternalServerError, err.Error())
- }
-
- if brw.Reader.Buffered() > 0 {
- if err := netConn.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- return nil, errors.New("websocket: client sent data before handshake is complete")
- }
-
- buf := bufioWriterBuffer(netConn, brw.Writer)
-
- var writeBuf []byte
- if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
- // Reuse hijacked write buffer as connection buffer.
- writeBuf = buf
- } else {
- if u.WriteBufferSize <= 0 {
- u.WriteBufferSize = defaultWriteBufferSize
- }
- u.WriteBufferSize += maxFrameHeaderSize
- if writeBuf == nil && u.WriteBufferPool == nil {
- writeBuf = make([]byte, u.WriteBufferSize)
- }
- }
-
- // Use larger of hijacked buffer and connection write buffer for header.
- p := buf
- if len(writeBuf) > len(p) {
- p = writeBuf
- }
- p = p[:0]
-
- p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
- for k, vs := range responseHeader {
- for _, v := range vs {
- p = append(p, k...)
- p = append(p, ": "...)
- for i := 0; i < len(v); i++ {
- b := v[i]
- if b <= 31 {
- // prevent response splitting.
- b = ' '
- }
- p = append(p, b)
- }
- p = append(p, "\r\n"...)
- }
- }
- p = append(p, "\r\n"...)
-
- // Clear deadlines set by HTTP server.
- if err := netConn.SetDeadline(time.Time{}); err != nil {
- if err := netConn.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- return nil, err
- }
-
- if u.HandshakeTimeout > 0 {
- if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
- if err := netConn.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- return nil, err
- }
- }
- if _, err = netConn.Write(p); err != nil {
- if err := netConn.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- return nil, err
- }
- if u.HandshakeTimeout > 0 {
- if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
- if err := netConn.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- return nil, err
- }
- }
-
- return netConn, nil
-}
-
-func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) {
- err := HandshakeError{message: reason}
- if u.Error != nil {
- u.Error(w, r, status, err)
- } else {
- w.Header().Set("Sec-Websocket-Version", "13")
- http.Error(w, http.StatusText(status), status)
- }
- return nil, err
-}
-
-type HandshakeError struct {
- message string
-}
-
-func (t HandshakeError) Error() string {
- return t.message
-}
go pfront.Test(ctx, *testP, logger.Base("测试"))
for i := 0; i < len(configS); i++ {
- go configS[i].Run(ctx, logger.Base(configS[i].Addr))
+ go configS[i].Run(ctx, logger)
}
// ctrl+c退出
--- /dev/null
+package front
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptrace"
+ "net/url"
+ "time"
+ _ "unsafe"
+
+ "github.com/gorilla/websocket"
+ pslice "github.com/qydysky/part/slice"
+ "golang.org/x/net/proxy"
+)
+
+func wsDealer(ctx context.Context, w http.ResponseWriter, r *http.Request, routePath string, backs []*Back, logger Logger, blocksi pslice.BlocksI[byte]) error {
+ var (
+ opT = time.Now()
+ resp *http.Response
+ conn net.Conn
+ chosenBack *Back
+ )
+
+ for 0 < len(backs) && (resp == nil || conn == nil) {
+ chosenBack = backs[0]
+ backs = backs[1:]
+
+ _, e := BodyMatchs(chosenBack.tmp.ReqBody, r)
+ if e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return errors.Join(ErrBodyCheckFail, e)
+ }
+
+ url := chosenBack.To
+ if chosenBack.PathAdd {
+ url += r.URL.String()
+ }
+
+ url = "ws" + url
+
+ reqHeader := make(http.Header)
+
+ if e := copyHeader(r.Header, reqHeader, chosenBack.tmp.ReqHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return e
+ }
+
+ conn, resp, e = DialContext(ctx, url, reqHeader)
+ if e != nil && !errors.Is(e, context.Canceled) {
+ chosenBack.Disable()
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ }
+ }
+
+ if 0 == len(backs) && (resp == nil || conn == nil) {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws 全部后端故障 %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+ return errors.New("全部后端故障")
+ } else if resp == nil || conn == nil {
+ return errors.New("后端故障")
+ }
+
+ logger.Debug(`T:`, fmt.Sprintf("%v > %v > %v ws ok %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, time.Since(opT)))
+
+ {
+ cookie := &http.Cookie{
+ Name: "_psign_" + cookie,
+ Value: chosenBack.Id(),
+ MaxAge: chosenBack.Splicing,
+ Path: "/",
+ }
+ if validCookieDomain(r.Host) {
+ cookie.Domain = r.Host
+ }
+ w.Header().Add("Set-Cookie", (cookie).String())
+ }
+
+ w.Header().Add(header+"Info", cookie+";"+chosenBack.Name)
+
+ defer conn.Close()
+
+ resHeader := make(http.Header)
+ if e := copyHeader(resp.Header, resHeader, chosenBack.tmp.ResHeader); e != nil {
+ logger.Warn(`W:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return e
+ }
+
+ if req, e := Upgrade(w, r, resHeader); e != nil {
+ return errors.Join(ErrResDoFail, e)
+ } else {
+ defer req.Close()
+
+ select {
+ case e := <-copyWsMsg(req, conn, blocksi):
+ if e != nil {
+ chosenBack.Disable()
+ logger.Error(`E:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return errors.Join(ErrCopy, e)
+ }
+ case e := <-copyWsMsg(conn, req, blocksi):
+ if e != nil {
+ chosenBack.Disable()
+ logger.Error(`E:`, fmt.Sprintf("%v > %v > %v ws %v %v", chosenBack.route.config.Addr, routePath, chosenBack.Name, e, time.Since(opT)))
+ return errors.Join(ErrCopy, e)
+ }
+ case <-ctx.Done():
+ }
+
+ return nil
+ }
+}
+
+func copyWsMsg(dst io.Writer, src io.Reader, blocksi pslice.BlocksI[byte]) <-chan error {
+ c := make(chan error, 1)
+ go func() {
+ if tmpbuf, put, e := blocksi.Get(); e != nil {
+ c <- e
+ } else {
+ defer put()
+ _, e := io.CopyBuffer(dst, src, tmpbuf)
+ c <- e
+ }
+ }()
+ return c
+}
+
+func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) {
+ d := websocket.DefaultDialer
+
+ challengeKey := requestHeader.Get("Sec-WebSocket-Key")
+
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ switch u.Scheme {
+ case "ws":
+ u.Scheme = "http"
+ case "wss":
+ u.Scheme = "https"
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if u.User != nil {
+ // User name and password are not allowed in websocket URIs.
+ return nil, nil, errMalformedURL
+ }
+
+ req := &http.Request{
+ Method: http.MethodGet,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: u.Host,
+ }
+ req = req.WithContext(ctx)
+
+ // Set the request headers using the capitalization for names and values in
+ // RFC examples. Although the capitalization shouldn't matter, there are
+ // servers that depend on it. The Header.Set method is not used because the
+ // method canonicalizes the header names.
+ for k, vs := range requestHeader {
+ req.Header[k] = vs
+ }
+
+ if d.HandshakeTimeout != 0 {
+ var cancel func()
+ ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+ defer cancel()
+ }
+
+ // Get network dial function.
+ var netDial func(network, add string) (net.Conn, error)
+
+ switch u.Scheme {
+ case "http":
+ if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
+ }
+ case "https":
+ if d.NetDialTLSContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialTLSContext(ctx, network, addr)
+ }
+ } else if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
+ }
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if netDial == nil {
+ netDialer := &net.Dialer{}
+ netDial = func(network, addr string) (net.Conn, error) {
+ return netDialer.DialContext(ctx, network, addr)
+ }
+ }
+
+ // If needed, wrap the dial function to set the connection deadline.
+ if deadline, ok := ctx.Deadline(); ok {
+ forwardDial := netDial
+ netDial = func(network, addr string) (net.Conn, error) {
+ c, err := forwardDial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ err = c.SetDeadline(deadline)
+ if err != nil {
+ if err := c.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ return c, nil
+ }
+ }
+
+ // If needed, wrap the dial function to connect through a proxy.
+ if d.Proxy != nil {
+ proxyURL, err := d.Proxy(req)
+ if err != nil {
+ return nil, nil, err
+ }
+ if proxyURL != nil {
+ dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
+ if err != nil {
+ return nil, nil, err
+ }
+ netDial = dialer.Dial
+ }
+ }
+
+ hostPort, hostNoPort := hostPortNoPort(u)
+ trace := httptrace.ContextClientTrace(ctx)
+ if trace != nil && trace.GetConn != nil {
+ trace.GetConn(hostPort)
+ }
+
+ netConn, err := netDial("tcp", hostPort)
+ if err != nil {
+ return nil, nil, err
+ }
+ if trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{
+ Conn: netConn,
+ })
+ }
+
+ if u.Scheme == "https" && d.NetDialTLSContext == nil {
+ // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
+ cfg := cloneTLSConfig(d.TLSClientConfig)
+ if cfg.ServerName == "" {
+ cfg.ServerName = hostNoPort
+ }
+ tlsConn := tls.Client(netConn, cfg)
+ netConn = tlsConn
+
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := doHandshake(ctx, tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
+ }
+
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ var br *bufio.Reader
+ if br == nil {
+ if d.ReadBufferSize == 0 {
+ d.ReadBufferSize = defaultReadBufferSize
+ } else if d.ReadBufferSize < maxControlFramePayloadSize {
+ // must be large enough for control frame
+ d.ReadBufferSize = maxControlFramePayloadSize
+ }
+ br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
+ }
+
+ if err := req.Write(netConn); err != nil {
+ return nil, nil, err
+ }
+
+ if trace != nil && trace.GotFirstResponseByte != nil {
+ if peek, err := br.Peek(1); err == nil && len(peek) == 1 {
+ trace.GotFirstResponseByte()
+ }
+ }
+
+ resp, err := http.ReadResponse(br, req)
+ if err != nil {
+ if d.TLSClientConfig != nil {
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto != "http/1.1" {
+ return nil, nil, fmt.Errorf(
+ "websocket: protocol %q was given but is not supported;"+
+ "sharing tls.Config with net/http Transport can cause this error: %w",
+ proto, err,
+ )
+ }
+ }
+ }
+ return nil, nil, err
+ }
+
+ if d.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ d.Jar.SetCookies(u, rc)
+ }
+ }
+
+ if resp.StatusCode != 101 ||
+ !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+ !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
+ resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+ // Before closing the network connection on return from this
+ // function, slurp up some of the response to aid application
+ // debugging.
+ buf := make([]byte, 1024)
+ n, _ := io.ReadFull(resp.Body, buf)
+ resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
+ log.Default().Println(resp.StatusCode, resp.Header)
+ return nil, resp, websocket.ErrBadHandshake
+ }
+
+ resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
+
+ if err := netConn.SetDeadline(time.Time{}); err != nil {
+ return nil, nil, err
+ }
+ return netConn, resp, nil
+}
+
+type netDialerFunc func(network, addr string) (net.Conn, error)
+
+func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
+ return fn(network, addr)
+}
+
+//go:linkname doHandshake github.com/gorilla/websocket.doHandshake
+func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error
+
+//go:linkname cloneTLSConfig github.com/gorilla/websocket.cloneTLSConfig
+func cloneTLSConfig(cfg *tls.Config) *tls.Config
+
+//go:linkname hostPortNoPort github.com/gorilla/websocket.hostPortNoPort
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string)
+
+//go:linkname errMalformedURL github.com/gorilla/websocket.errMalformedURL
+var errMalformedURL error
+
+//go:linkname errInvalidCompression github.com/gorilla/websocket.errInvalidCompression
+var errInvalidCompression error
+
+//go:linkname generateChallengeKey github.com/gorilla/websocket.generateChallengeKey
+func generateChallengeKey() (string, error)
+
+//go:linkname tokenListContainsValue github.com/gorilla/websocket.tokenListContainsValue
+func tokenListContainsValue(header http.Header, name string, value string) bool
+
+//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
+// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
+
+//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
+func checkSameOrigin(r *http.Request) bool
+
+//go:linkname isValidChallengeKey github.com/gorilla/websocket.isValidChallengeKey
+func isValidChallengeKey(s string) bool
+
+//go:linkname selectSubprotocol github.com/gorilla/websocket.(*Upgrader).selectSubprotocol
+func selectSubprotocol(u *websocket.Upgrader, r *http.Request, responseHeader http.Header) string
+
+//go:linkname parseExtensions github.com/gorilla/websocket.parseExtensions
+func parseExtensions(header http.Header) []map[string]string
+
+//go:linkname bufioReaderSize github.com/gorilla/websocket.bufioReaderSize
+func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int
+
+//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
+func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
+
+//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
+func computeAcceptKey(challengeKey string) string
+
+const (
+ maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
+ defaultReadBufferSize = 4096
+ defaultWriteBufferSize = 4096
+ maxControlFramePayloadSize = 125
+)
+
+func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) {
+ u := &websocket.Upgrader{}
+ h, ok := w.(http.Hijacker)
+ if !ok {
+ return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
+ }
+ var brw *bufio.ReadWriter
+ netConn, brw, err := h.Hijack()
+ if err != nil {
+ return returnError(u, w, r, http.StatusInternalServerError, err.Error())
+ }
+
+ if brw.Reader.Buffered() > 0 {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, errors.New("websocket: client sent data before handshake is complete")
+ }
+
+ buf := bufioWriterBuffer(netConn, brw.Writer)
+
+ var writeBuf []byte
+ if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
+ // Reuse hijacked write buffer as connection buffer.
+ writeBuf = buf
+ } else {
+ if u.WriteBufferSize <= 0 {
+ u.WriteBufferSize = defaultWriteBufferSize
+ }
+ u.WriteBufferSize += maxFrameHeaderSize
+ if writeBuf == nil && u.WriteBufferPool == nil {
+ writeBuf = make([]byte, u.WriteBufferSize)
+ }
+ }
+
+ // Use larger of hijacked buffer and connection write buffer for header.
+ p := buf
+ if len(writeBuf) > len(p) {
+ p = writeBuf
+ }
+ p = p[:0]
+
+ p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
+ for k, vs := range responseHeader {
+ for _, v := range vs {
+ p = append(p, k...)
+ p = append(p, ": "...)
+ for i := 0; i < len(v); i++ {
+ b := v[i]
+ if b <= 31 {
+ // prevent response splitting.
+ b = ' '
+ }
+ p = append(p, b)
+ }
+ p = append(p, "\r\n"...)
+ }
+ }
+ p = append(p, "\r\n"...)
+
+ // Clear deadlines set by HTTP server.
+ if err := netConn.SetDeadline(time.Time{}); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+
+ if u.HandshakeTimeout > 0 {
+ if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ }
+ if _, err = netConn.Write(p); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ if u.HandshakeTimeout > 0 {
+ if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
+ if err := netConn.Close(); err != nil {
+ log.Printf("websocket: failed to close network connection: %v", err)
+ }
+ return nil, err
+ }
+ }
+
+ return netConn, nil
+}
+
+func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) {
+ err := HandshakeError{message: reason}
+ if u.Error != nil {
+ u.Error(w, r, status, err)
+ } else {
+ w.Header().Set("Sec-Websocket-Version", "13")
+ http.Error(w, http.StatusText(status), status)
+ }
+ return nil, err
+}
+
+type HandshakeError struct {
+ message string
+}
+
+func (t HandshakeError) Error() string {
+ return t.message
+}