"github.com/gorilla/websocket"
pctx "github.com/qydysky/part/ctx"
+ pslice "github.com/qydysky/part/slice"
pweb "github.com/qydysky/part/web"
"golang.org/x/net/proxy"
)
return e
}
- fmt.Println(reqHeader.Get("Sec-WebSocket-Key"))
-
- if res, resp, e := DialContext(websocket.DefaultDialer, context.Background(), url, reqHeader); e != nil {
+ if res, resp, e := DialContext(ctx, url, reqHeader); e != nil {
return errors.Join(ErrReqDoFail, e)
} else {
defer res.Close()
return e
}
- fmt.Println(resHeader.Get("Sec-Websocket-Accept"))
-
- if req, e := Upgrade(&websocket.Upgrader{}, w, r, resHeader); e != nil {
+ if req, e := Upgrade(w, r, resHeader); e != nil {
return errors.Join(ErrResDoFail, e)
} else {
defer req.Close()
- fin := make(chan error, 2)
- reqc := req.NetConn()
- resc := res.NetConn()
- go func() {
- ctx1, done1 := pctx.WaitCtx(ctx)
- defer done1()
- select {
- case fin <- copyWsMsg(reqc, resc):
- case <-ctx1.Done():
- fin <- nil
+ select {
+ case e := <-copyWsMsg(req, res):
+ if e != nil {
+ logger.Error(`E:`, fmt.Sprintf("%s=>%s s->c %v", routePath, back.Name, e))
+ return errors.Join(ErrCopy, e)
}
- }()
- go func() {
- ctx1, done1 := pctx.WaitCtx(ctx)
- defer done1()
- select {
- case fin <- copyWsMsg(resc, reqc):
- case <-ctx1.Done():
- fin <- nil
+ case e := <-copyWsMsg(res, req):
+ if e != nil {
+ logger.Error(`E:`, fmt.Sprintf("%s=>%s c->s %v", routePath, back.Name, e))
+ return errors.Join(ErrCopy, e)
}
- }()
- if e := <-fin; e != nil {
- logger.Error(`E:`, fmt.Sprintf("%s=>%s %v", routePath, back.Name, e))
- return errors.Join(ErrCopy, e)
+ case <-ctx.Done():
}
+
return nil
}
}
return nil
}
-func copyWsMsg(dst io.Writer, src io.Reader) (e error) {
- _, e = io.Copy(dst, src)
- return
+var copyBuf = pslice.NewBlocks[byte](16*1024, 100)
+
+func copyWsMsg(dst io.Writer, src io.Reader) <-chan error {
+ c := make(chan error, 1)
+ go func() {
+ if tmpbuf, put, e := copyBuf.Get(); e != nil {
+ c <- e
+ } else {
+ _, e := io.CopyBuffer(dst, src, tmpbuf)
+ put()
+ c <- e
+ }
+ }()
+ return c
}
-func DialContext(d *websocket.Dialer, ctx context.Context, urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) {
- if d == nil {
- d = websocket.DefaultDialer
- }
+func DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (net.Conn, *http.Response, error) {
+ d := websocket.DefaultDialer
challengeKey := requestHeader.Get("Sec-WebSocket-Key")
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
for k, vs := range requestHeader {
- switch {
- case k == "Host":
- if len(vs) > 0 {
- req.Host = vs[0]
- }
- case k == "Sec-Websocket-Protocol":
- return nil, nil, errors.New("websocket: header not allowed: " + k)
- default:
- req.Header[k] = vs
- }
+ req.Header[k] = vs
}
if d.HandshakeTimeout != 0 {
})
}
- defer func() {
- if netConn != nil {
- if err := netConn.Close(); err != nil {
- log.Printf("websocket: failed to close network connection: %v", err)
- }
- }
- }()
-
if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
br = bufio.NewReaderSize(netConn, d.ReadBufferSize)
}
- conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, br, nil)
-
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, nil, err
}
- netConn = nil // to avoid close in defer.
- return conn, resp, nil
+ return netConn, resp, nil
}
type netDialerFunc func(network, addr string) (net.Conn, error)
func tokenListContainsValue(header http.Header, name string, value string) bool
//go:linkname returnError github.com/gorilla/websocket.(*Upgrader).returnError
-func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
+// func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (*websocket.Conn, error)
//go:linkname checkSameOrigin github.com/gorilla/websocket.checkSameOrigin
func checkSameOrigin(r *http.Request) bool
//go:linkname bufioWriterBuffer github.com/gorilla/websocket.bufioWriterBuffer
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte
-//go:linkname newConn github.com/gorilla/websocket.newConn
-func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool websocket.BufferPool, br *bufio.Reader, writeBuf []byte) *websocket.Conn
-
//go:linkname computeAcceptKey github.com/gorilla/websocket.computeAcceptKey
func computeAcceptKey(challengeKey string) string
maxControlFramePayloadSize = 125
)
-func Upgrade(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) {
- const badHandshake = "websocket: the client is not using the websocket protocol: "
-
- if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
- return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
- }
-
- if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
- return returnError(u, w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
- }
-
- if r.Method != http.MethodGet {
- return returnError(u, w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
- }
-
- if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
- return returnError(u, w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
- }
-
- checkOrigin := u.CheckOrigin
- if checkOrigin == nil {
- checkOrigin = checkSameOrigin
- }
- if !checkOrigin(r) {
- return returnError(u, w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
- }
-
- challengeKey := r.Header.Get("Sec-Websocket-Key")
- if !isValidChallengeKey(challengeKey) {
- return returnError(u, w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
- }
-
+func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (net.Conn, error) {
+ u := &websocket.Upgrader{}
h, ok := w.(http.Hijacker)
if !ok {
return returnError(u, w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
return nil, errors.New("websocket: client sent data before handshake is complete")
}
- var br *bufio.Reader
- if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
- // Reuse hijacked buffered reader as connection reader.
- br = brw.Reader
- }
-
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
}
}
- c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
-
// Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(writeBuf) > len(p) {
p = append(p, "HTTP/1.1 101 Switching Protocols\r\n"...)
for k, vs := range responseHeader {
- if k == "Sec-Websocket-Protocol" || k == "Sec-Websocket-Extensions" {
- continue
- }
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
}
}
- return c, nil
+ return netConn, nil
+}
+
+func returnError(u *websocket.Upgrader, w http.ResponseWriter, r *http.Request, status int, reason string) (net.Conn, error) {
+ err := HandshakeError{message: reason}
+ if u.Error != nil {
+ u.Error(w, r, status, err)
+ } else {
+ w.Header().Set("Sec-Websocket-Version", "13")
+ http.Error(w, http.StatusText(status), status)
+ }
+ return nil, err
+}
+
+type HandshakeError struct {
+ message string
+}
+
+func (t HandshakeError) Error() string {
+ return t.message
}