]> 127.0.0.1 Git - part/.git/commitdiff
1 (#45) v0.28.20250420054637
authorqydysky <qydysky@foxmail.com>
Sun, 20 Apr 2025 05:46:29 +0000 (13:46 +0800)
committerGitHub <noreply@github.com>
Sun, 20 Apr 2025 05:46:29 +0000 (13:46 +0800)
Net.go

diff --git a/Net.go b/Net.go
index 901e4684b635dd7ef621186268e8dc5ec58ec346..721667e9e8cdfe01aa95942385b5fee11abaf6a6 100644 (file)
--- a/Net.go
+++ b/Net.go
@@ -2,25 +2,29 @@ package part
 
 import (
        "bytes"
+       "context"
        "errors"
        "fmt"
        "io"
        "net"
        "net/url"
+       "os"
        "os/exec"
        "runtime"
        "strings"
-       "sync"
        "time"
 
+       "github.com/dustin/go-humanize"
        "github.com/miekg/dns"
+       pe "github.com/qydysky/part/errors"
        pfile "github.com/qydysky/part/file"
+       psync "github.com/qydysky/part/sync"
 )
 
 var (
        ErrDnsNoAnswer      = errors.New("ErrDnsNoAnswer")
        ErrNetworkNoSupport = errors.New("ErrNetworkNoSupport")
-       ErrUdpOverflow      = errors.New("ErrUdpOverflow") // default:1500 set higher pkgSize at NewUdpListener
+       ErrUdpOverflow      = errors.New("ErrUdpOverflow")
 )
 
 type netl struct {
@@ -70,80 +74,81 @@ func (*netl) TestDial(network, address string, Timeout int) bool {
        return true
 }
 
-const (
-       ErrorMsg = iota
-       WarnMsg
-       AcceptMsg
-       DenyMsg
-       LisnMsg
-       ClosMsg
-)
-
-// when Type is ErrorMsg, Msg is set to error
-// when Type is AcceptMsg, Msg is set to net.Addr
-// when Type is DenyMsg, Msg is set to net.Addr
-// when Type is LisnMsg, Msg is set to net.Addr
-type ForwardMsg struct {
-       Type int
-       Msg  interface{}
-}
-
 // 桥
 func connBridge(a, b net.Conn, bufSize int) {
-       fmt.Println(b.LocalAddr())
-       var wg sync.WaitGroup
+       var wg = make(chan struct{}, 3)
 
-       wg.Add(2)
-       go func() {
-               defer func() {
-                       fmt.Println("close r", b.LocalAddr())
-                       wg.Done()
-               }()
-
-               buf := make([]byte, bufSize)
+       buf := make([]byte, bufSize*2)
 
+       go func() {
+               buf := buf[:bufSize]
                for {
                        if n, err := a.Read(buf); err != nil {
-                               a.Close()
-                               return
+                               if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrDeadlineExceeded) && !errors.Is(err, net.ErrClosed) {
+                                       fmt.Println(err)
+                               }
+                               break
                        } else if _, err = b.Write(buf[:n]); err != nil {
-                               fmt.Println(err)
-                               return
+                               if !errors.Is(err, io.EOF) {
+                                       fmt.Println(err)
+                               }
+                               break
                        }
                }
+               wg <- struct{}{}
        }()
 
        go func() {
-               defer func() {
-                       fmt.Println("close w", b.LocalAddr())
-                       wg.Done()
-               }()
-
-               buf := make([]byte, bufSize)
-
+               buf := buf[bufSize:]
                for {
                        if n, err := b.Read(buf); err != nil {
-                               b.Close()
-                               return
+                               if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrDeadlineExceeded) && !errors.Is(err, net.ErrClosed) {
+                                       fmt.Println(err)
+                               }
+                               break
                        } else if _, err = a.Write(buf[:n]); err != nil {
-                               fmt.Println(err)
-                               return
+                               if !errors.Is(err, io.EOF) {
+                                       fmt.Println(err)
+                               }
+                               break
                        }
                }
+               wg <- struct{}{}
        }()
 
-       wg.Wait()
-       fmt.Println("close", b.LocalAddr())
+       <-wg
+       a.Close()
+       b.Close()
 }
 
-func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
-       msg_chan = make(chan ForwardMsg, 1000)
+var (
+       ErrForwardAccept pe.Action = `ErrForwardAccept`
+       ErrForwardDail   pe.Action = `ErrForwardDail`
+)
+
+type ForwardMsgFunc interface {
+       ErrorMsg(targetaddr, listenaddr string, e error)
+       WarnMsg(targetaddr, listenaddr string, e error)
+       AcceptMsg(remote net.Addr, targetaddr string) (ConFinMsg func())
+       DenyMsg(remote net.Addr, targetaddr string)
+       LisnMsg(targetaddr, listenaddr string)
+       ClosMsg(targetaddr, listenaddr string)
+}
+
+func Forward(targetaddr, listenaddr string, acceptCIDRs []string, callBack ForwardMsgFunc) (closef func()) {
        closef = func() {}
 
        lisNet := strings.Split(listenaddr, "://")[0]
        lisAddr := strings.Split(listenaddr, "://")[1]
        tarNet := strings.Split(targetaddr, "://")[0]
        tarAddr := strings.Split(targetaddr, "://")[1]
+       lisIsUdp := strings.Contains(lisNet, "udp")
+       tarIsUdp := strings.Contains(tarNet, "udp")
+
+       if (!lisIsUdp && tarIsUdp) || (lisIsUdp && !tarIsUdp) {
+               callBack.ErrorMsg(targetaddr, listenaddr, ErrNetworkNoSupport)
+               return
+       }
 
        //尝试监听
        var listener net.Listener
@@ -158,13 +163,7 @@ func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func()
                        err = ErrNetworkNoSupport
                }
                if err != nil {
-                       select {
-                       default:
-                       case msg_chan <- ForwardMsg{
-                               Type: ErrorMsg,
-                               Msg:  err,
-                       }:
-                       }
+                       callBack.ErrorMsg(targetaddr, listenaddr, err)
                        return
                }
        }
@@ -176,43 +175,24 @@ func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func()
                        err = ErrNetworkNoSupport
                }
                if err != nil {
-                       select {
-                       default:
-                       case msg_chan <- ForwardMsg{
-                               Type: ErrorMsg,
-                               Msg:  err,
-                       }:
-                       }
+                       callBack.ErrorMsg(targetaddr, listenaddr, err)
                        return
                }
        }
 
-       closec := make(chan struct{})
        //初始化关闭方法
        closef = func() {
                listener.Close()
-               close(closec)
+               callBack.ClosMsg(targetaddr, listenaddr)
        }
 
        //返回监听地址
-       select {
-       default:
-       case msg_chan <- ForwardMsg{
-               Type: LisnMsg,
-               Msg:  listener.Addr(),
-       }:
-       }
+       callBack.LisnMsg(targetaddr, listenaddr)
 
        matchfunc := []func(ip net.IP) bool{}
        for _, cidr := range acceptCIDRs {
                if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
-                       select {
-                       default:
-                       case msg_chan <- ForwardMsg{
-                               Type: ErrorMsg,
-                               Msg:  err,
-                       }:
-                       }
+                       callBack.ErrorMsg(targetaddr, listenaddr, err)
                        return
                } else {
                        matchfunc = append(matchfunc, cidrx.Contains)
@@ -220,33 +200,23 @@ func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func()
        }
 
        //开始准备转发
-       go func(listener net.Listener, msg_chan chan ForwardMsg) {
-               defer close(msg_chan)
+       go func(listener net.Listener) {
                defer listener.Close()
 
                for {
                        proxyconn, err := listener.Accept()
+                       if errors.Is(err, ErrUdpConnected) {
+                               continue
+                       }
                        if err != nil {
                                //返回Accept错误
-                               select {
-                               default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: WarnMsg,
-                                       Msg:  err,
-                               }:
-                               }
+                               callBack.WarnMsg(targetaddr, listenaddr, pe.Join(ErrForwardAccept, err))
                                continue
                        }
 
                        host, _, err := net.SplitHostPort(proxyconn.RemoteAddr().String())
                        if err != nil {
-                               select {
-                               default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: WarnMsg,
-                                       Msg:  err,
-                               }:
-                               }
+                               callBack.WarnMsg(targetaddr, listenaddr, err)
                                continue
                        }
 
@@ -258,122 +228,157 @@ func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func()
                        }
                        if !accept {
                                //返回Deny
-                               select {
-                               default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: DenyMsg,
-                                       Msg:  proxyconn.RemoteAddr(),
-                               }:
-                               }
+                               callBack.DenyMsg(proxyconn.RemoteAddr(), targetaddr)
                                proxyconn.Close()
                                continue
                        }
                        //返回Accept
-                       select {
-                       default:
-                       case msg_chan <- ForwardMsg{
-                               Type: AcceptMsg,
-                               Msg:  proxyconn.RemoteAddr(),
-                       }:
-                       }
-
-                       select {
-                       case <-closec:
-                       default:
-                               break
-                       }
+                       conFin := callBack.AcceptMsg(proxyconn.RemoteAddr(), targetaddr)
 
                        go func() {
+                               defer conFin()
                                targetconn, err := net.Dial(tarNet, tarAddr)
                                if err != nil {
-                                       select {
-                                       default:
-                                       case msg_chan <- ForwardMsg{
-                                               Type: WarnMsg,
-                                               Msg:  err,
-                                       }:
-                                       }
+                                       callBack.WarnMsg(targetaddr, listenaddr, pe.Join(ErrForwardDail, err))
                                        return
                                }
 
-                               connBridge(proxyconn, targetconn, 20480)
-
-                               switch lisNet {
-                               case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
-                               case "udp", "udp4", "udp6":
-                                       time.Sleep(time.Second * 10)
-                               default:
-                               }
+                               // if !lisIsUdp && !tarIsUdp {
+                               //      tcp2tcpConnBridge(proxyconn, targetconn, 65535)
+                               // } else if lisIsUdp && tarIsUdp {
+                               connBridge(proxyconn, targetconn, 65535)
+                               // } else {
+                               //      connBridge(proxyconn, targetconn, 65535)
+                               // }
                        }()
                }
-       }(listener, msg_chan)
+       }(listener)
 
        return
 }
 
+var ErrUdpConnOverflow = errors.New(`ErrUdpConnOverflow`)
+
 type udpConn struct {
+       e         error
        conn      *net.UDPConn
        remoteAdd *net.UDPAddr
-       reader    io.Reader
+       ctx       context.Context
+       ctxCancel context.CancelFunc
+       buf       chan []byte
+       closef    func() error
 }
 
-func NewUdpConn(per []byte, remoteAdd *net.UDPAddr, conn *net.UDPConn) *udpConn {
-       return &udpConn{
-               conn:      conn,
-               remoteAdd: remoteAdd,
-               reader:    bytes.NewReader(per),
+func (t *udpConn) SetBuf(b []byte) {
+       tmp := make([]byte, len(b))
+       copy(tmp, b)
+       select {
+       case t.buf <- tmp:
+       default:
+               t.e = ErrUdpConnOverflow
+       }
+}
+func (t *udpConn) Read(b []byte) (n int, err error) {
+       select {
+       case tmp := <-t.buf:
+               n = copy(b, tmp)
+       case <-t.ctx.Done():
+               err = os.ErrDeadlineExceeded
+       }
+       return
+}
+func (t *udpConn) Write(b []byte) (n int, err error) {
+       select {
+       case <-t.ctx.Done():
+               err = os.ErrDeadlineExceeded
+       default:
+               n, err = t.conn.WriteToUDP(b, t.remoteAdd)
+               if err != nil {
+                       t.ctx.Done()
+               }
        }
+       return
 }
-func (t udpConn) Read(b []byte) (n int, err error) {
-       return t.reader.Read(b)
+func (t *udpConn) Close() error {
+       t.ctxCancel()
+       return t.closef()
 }
-func (t udpConn) Write(b []byte) (n int, err error) {
-       return t.conn.WriteToUDP(b, t.remoteAdd)
+func (t *udpConn) LocalAddr() net.Addr  { return t.conn.LocalAddr() }
+func (t *udpConn) RemoteAddr() net.Addr { return t.remoteAdd }
+func (t *udpConn) SetDeadline(b time.Time) error {
+       time.AfterFunc(time.Until(b), func() {
+               t.Close()
+       })
+       return nil
 }
-func (t udpConn) Close() error                       { return nil }
-func (t udpConn) LocalAddr() net.Addr                { return t.conn.LocalAddr() }
-func (t udpConn) RemoteAddr() net.Addr               { return t.remoteAdd }
-func (t udpConn) SetDeadline(b time.Time) error      { return t.conn.SetDeadline(b) }
-func (t udpConn) SetReadDeadline(b time.Time) error  { return t.conn.SetReadDeadline(b) }
-func (t udpConn) SetWriteDeadline(b time.Time) error { return t.conn.SetWriteDeadline(b) }
+func (t *udpConn) SetReadDeadline(b time.Time) error  { return t.SetDeadline(b) }
+func (t *udpConn) SetWriteDeadline(b time.Time) error { return t.SetDeadline(b) }
 
 type udpLis struct {
-       addr    net.Addr
-       conn    *net.UDPConn
-       pkgSize int
+       udpAddr *net.UDPAddr
+       c       <-chan *udpConn
+       closef  func() error
 }
 
-func NewUdpListener(network, listenaddr string, pkgSize ...int) (*udpLis, error) {
+var ErrUdpConnected error = errors.New("ErrUdpConnected")
+
+func NewUdpListener(network, listenaddr string) (*udpLis, error) {
        udpAddr, err := net.ResolveUDPAddr(network, listenaddr)
        if err != nil {
                return nil, err
        }
-       pkgSize = append(pkgSize, 1500)
        if conn, err := net.ListenUDP(network, udpAddr); err != nil {
                return nil, err
        } else {
-               return &udpLis{
-                       addr:    udpAddr,
-                       conn:    conn,
-                       pkgSize: pkgSize[0],
-               }, nil
+               c := make(chan *udpConn, 10)
+               lis := &udpLis{
+                       udpAddr: udpAddr,
+                       c:       c,
+                       closef: func() error {
+                               return conn.Close()
+                       },
+               }
+               go func() {
+                       var link psync.MapG[string, *udpConn]
+                       buf := make([]byte, humanize.MByte)
+                       for {
+                               n, remoteAdd, e := conn.ReadFromUDP(buf)
+                               if e != nil {
+                                       c <- &udpConn{e: e}
+                                       return
+                               }
+                               if udpc, ok := link.Load(remoteAdd.String()); ok {
+                                       udpc.SetBuf(buf[:n])
+                                       c <- &udpConn{e: ErrUdpConnected}
+                               } else {
+                                       udpc := &udpConn{
+                                               conn:      conn,
+                                               remoteAdd: remoteAdd,
+                                               closef: func() error {
+                                                       link.Delete(remoteAdd.String())
+                                                       return nil
+                                               },
+                                               buf: make(chan []byte, 5),
+                                       }
+                                       udpc.ctx, udpc.ctxCancel = context.WithTimeout(context.Background(), time.Second*30)
+                                       udpc.SetBuf(buf[:n])
+                                       link.Store(remoteAdd.String(), udpc)
+                                       c <- udpc
+                               }
+                       }
+               }()
+               return lis, nil
        }
 }
-func (t udpLis) Accept() (net.Conn, error) {
-       buf := make([]byte, t.pkgSize)
-       if n, remoteAdd, e := t.conn.ReadFromUDP(buf); e != nil {
-               return nil, e
-       } else if n == t.pkgSize {
-               return nil, ErrUdpOverflow
-       } else {
-               return NewUdpConn(buf[:n], remoteAdd, t.conn), nil
-       }
+func (t *udpLis) Accept() (net.Conn, error) {
+       udpc := <-t.c
+       return udpc, udpc.e
 }
-func (t udpLis) Close() error {
-       return t.conn.Close()
+func (t *udpLis) Close() error {
+       return t.closef()
 }
-func (t udpLis) Addr() net.Addr {
-       return t.addr
+func (t *udpLis) Addr() net.Addr {
+       return t.udpAddr
 }
 
 // type ConnPool struct {