]> 127.0.0.1 Git - part/.git/commitdiff
1 v0.28.20240414033603
authorqydysky <qydysky@foxmail.com>
Sun, 14 Apr 2024 03:30:42 +0000 (03:30 +0000)
committerqydysky <qydysky@foxmail.com>
Sun, 14 Apr 2024 03:30:42 +0000 (03:30 +0000)
Net.go

diff --git a/Net.go b/Net.go
index d586e5af609f27a381708dacda5cb10a98657022..5446d4a19179a62f9c676e58f10b605bc9faa961 100644 (file)
--- a/Net.go
+++ b/Net.go
@@ -3,6 +3,7 @@ package part
 import (
        "bytes"
        "errors"
+       "io"
        "net"
        "net/url"
        "os/exec"
@@ -13,6 +14,13 @@ import (
 
        "github.com/miekg/dns"
        pool "github.com/qydysky/part/pool"
+       psync "github.com/qydysky/part/sync"
+)
+
+var (
+       ErrDnsNoAnswer      = errors.New("ErrDnsNoAnswer")
+       ErrNetworkNoSupport = errors.New("ErrNetworkNoSupport")
+       ErrUdpOverflow      = errors.New("ErrUdpOverflow") // default:1500 set higher pkgSize at NewUdpListener
 )
 
 type netl struct {
@@ -45,7 +53,7 @@ func (this *netl) Nslookup(target string) error {
                return err
        }
        if len(r.Answer) == 0 {
-               return errors.New("no answer")
+               return ErrDnsNoAnswer
        }
 
        this.RV = append(this.RV, dns.Field(r.Answer[0], 1))
@@ -86,8 +94,6 @@ func connBridge(a, b net.Conn) {
        wg.Add(2)
        go func() {
                defer func() {
-                       a.Close()
-                       b.Close()
                        fin <- true
                        wg.Done()
                }()
@@ -99,20 +105,17 @@ func connBridge(a, b net.Conn) {
                        case <-fin:
                                return
                        default:
-                               n, err := a.Read(buf)
-
-                               if err != nil {
+                               if n, err := a.Read(buf); err != nil {
+                                       return
+                               } else if _, err = b.Write(buf[:n]); err != nil {
                                        return
                                }
-                               b.Write(buf[:n])
                        }
                }
        }()
 
        go func() {
                defer func() {
-                       a.Close()
-                       b.Close()
                        fin <- true
                        wg.Done()
                }()
@@ -124,12 +127,11 @@ func connBridge(a, b net.Conn) {
                        case <-fin:
                                return
                        default:
-                               n, err := b.Read(buf)
-
-                               if err != nil {
+                               if n, err := b.Read(buf); err != nil {
+                                       return
+                               } else if _, err = a.Write(buf[:n]); err != nil {
                                        return
                                }
-                               a.Write(buf[:n])
                        }
                }
        }()
@@ -137,22 +139,54 @@ func connBridge(a, b net.Conn) {
        wg.Wait()
 }
 
-// "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
-func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
-       //初始化消息通道
+func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
        msg_chan = make(chan ForwardMsg, 1000)
 
+       lisNet := strings.Split(listenaddr, "://")[0]
+       lisAddr := strings.Split(listenaddr, "://")[1]
+       tarNet := strings.Split(targetaddr, "://")[0]
+       tarAddr := strings.Split(targetaddr, "://")[1]
+
        //尝试监听
-       listener, err := net.Listen(network, listenaddr)
-       if err != nil {
-               select {
+       var listener net.Listener
+       {
+               var err error
+               switch lisNet {
+               case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
+                       listener, err = net.Listen(lisNet, lisAddr)
+               case "udp", "udp4", "udp6":
+                       listener, err = NewUdpListener(lisNet, lisAddr)
                default:
-               case msg_chan <- ForwardMsg{
-                       Type: ErrorMsg,
-                       Msg:  err,
-               }:
+                       err = ErrNetworkNoSupport
+               }
+               if err != nil {
+                       select {
+                       default:
+                       case msg_chan <- ForwardMsg{
+                               Type: ErrorMsg,
+                               Msg:  err,
+                       }:
+                       }
+                       return
+               }
+       }
+       {
+               var err error
+               switch lisNet {
+               case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip", "ip4", "ip6", "unix", "unixgram", "unixpacket":
+               default:
+                       err = ErrNetworkNoSupport
+               }
+               if err != nil {
+                       select {
+                       default:
+                       case msg_chan <- ForwardMsg{
+                               Type: ErrorMsg,
+                               Msg:  err,
+                       }:
+                       }
+                       return
                }
-               return
        }
 
        closec := make(chan struct{})
@@ -172,7 +206,6 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos
        }
 
        matchfunc := []func(ip net.IP) bool{}
-
        for _, cidr := range acceptCIDRs {
                if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
                        select {
@@ -189,10 +222,14 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos
        }
 
        //开始准备转发
-       go func(listener net.Listener, targetaddr, network string, msg_chan chan ForwardMsg) {
+       go func(listener net.Listener, msg_chan chan ForwardMsg) {
                defer close(msg_chan)
                defer listener.Close()
 
+               tarConnPool := NewConnPool(100, func() *net.Conn {
+                       conn, _ := net.Dial(tarNet, tarAddr)
+                       return &conn
+               })
                for {
                        proxyconn, err := listener.Accept()
                        if err != nil {
@@ -211,7 +248,7 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos
 
                        var deny bool
                        for i := 0; !deny && i < len(matchfunc); i++ {
-                               deny = deny && !matchfunc[i](ip)
+                               deny = deny || !matchfunc[i](ip)
                        }
                        if deny {
                                //返回Deny
@@ -241,216 +278,319 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos
                                break
                        }
 
-                       targetconn, err := net.Dial(network, targetaddr)
-                       if err != nil {
-                               select {
+                       go func() {
+                               targetconn, putBackf := tarConnPool.Get(tarAddr)
+                               defer putBackf()
+
+                               connBridge(proxyconn, targetconn)
+
+                               switch lisNet {
+                               case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
+                               case "udp", "udp4", "udp6":
+                                       time.Sleep(time.Second * 10)
                                default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: ErrorMsg,
-                                       Msg:  err,
-                               }:
                                }
                                proxyconn.Close()
-                               continue
-                       }
-
-                       go connBridge(proxyconn, targetconn)
+                               targetconn.Close()
+                       }()
                }
-       }(listener, targetaddr, network, msg_chan)
+       }(listener, msg_chan)
 
        return
 }
 
-func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
-       //初始化消息通道
-       msg_chan = make(chan ForwardMsg, 1000)
+type udpConn struct {
+       conn      *net.UDPConn
+       remoteAdd *net.UDPAddr
+       reader    io.Reader
+}
 
-       lisAddr := func(network, listenaddr string) (*net.UDPConn, error) {
-               if udpAddr, err := net.ResolveUDPAddr(network, listenaddr); err != nil {
-                       return nil, err
-               } else if conn, err := net.ListenUDP(network, udpAddr); err != nil {
-                       return nil, err
-               } else {
-                       return conn, nil
-               }
+func NewUdpConn(per []byte, remoteAdd *net.UDPAddr, conn *net.UDPConn) *udpConn {
+       return &udpConn{
+               conn:      conn,
+               remoteAdd: remoteAdd,
+               reader:    io.MultiReader(bytes.NewReader(per), conn),
        }
+}
+func (t udpConn) Read(b []byte) (n int, err error) {
+       return t.reader.Read(b)
+}
+func (t udpConn) Write(b []byte) (n int, err error) {
+       return t.conn.WriteToUDP(b, t.remoteAdd)
+}
+func (t udpConn) Close() error                       { return nil }
+func (t udpConn) LocalAddr() net.Addr                { return t.conn.LocalAddr() }
+func (t udpConn) RemoteAddr() net.Addr               { return t.remoteAdd }
+func (t udpConn) SetDeadline(b time.Time) error      { return t.conn.SetDeadline(b) }
+func (t udpConn) SetReadDeadline(b time.Time) error  { return t.conn.SetReadDeadline(b) }
+func (t udpConn) SetWriteDeadline(b time.Time) error { return t.conn.SetWriteDeadline(b) }
+
+type udpLis struct {
+       addr    net.Addr
+       conn    *net.UDPConn
+       pkgSize int
+}
 
-       targetAddr, err := net.ResolveUDPAddr(network, targetaddr)
+func NewUdpListener(network, listenaddr string, pkgSize ...int) (*udpLis, error) {
+       udpAddr, err := net.ResolveUDPAddr(network, listenaddr)
        if err != nil {
-               select {
-               default:
-               case msg_chan <- ForwardMsg{
-                       Type: ErrorMsg,
-                       Msg:  err,
-               }:
-               }
-               return
+               return nil, err
        }
-
-       serConn, err := lisAddr(network, listenaddr)
-       if err != nil {
-               select {
-               default:
-               case msg_chan <- ForwardMsg{
-                       Type: ErrorMsg,
-                       Msg:  err,
-               }:
-               }
-               return
+       pkgSize = append(pkgSize, 1500)
+       if conn, err := net.ListenUDP(network, udpAddr); err != nil {
+               return nil, err
+       } else {
+               return &udpLis{
+                       addr:    udpAddr,
+                       conn:    conn,
+                       pkgSize: pkgSize[0],
+               }, nil
        }
-
-       closec := make(chan struct{})
-       //初始化关闭方法
-       closef = func() {
-               serConn.Close()
-               close(closec)
+}
+func (t udpLis) Accept() (net.Conn, error) {
+       buf := make([]byte, t.pkgSize)
+       if n, remoteAdd, e := t.conn.ReadFromUDP(buf); e != nil {
+               return nil, e
+       } else if n == t.pkgSize {
+               return nil, ErrUdpOverflow
+       } else {
+               return NewUdpConn(buf[:n], remoteAdd, t.conn), nil
        }
+}
+func (t udpLis) Close() error {
+       return t.conn.Close()
+}
+func (t udpLis) Addr() net.Addr {
+       return t.addr
+}
 
-       //返回监听端口
-       select {
-       default:
-       case msg_chan <- ForwardMsg{
-               Type: LisnMsg,
-               Msg:  serConn.LocalAddr(),
-       }:
-       }
+type ConnPool struct {
+       connMap *psync.Map
+       pool    *pool.Buf[ConnPoolItem]
+}
 
-       matchfunc := []func(ip net.IP) bool{}
+type ConnPoolItem struct {
+       p      *net.Conn
+       pooled bool
+}
 
-       for _, cidr := range acceptCIDRs {
-               if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
-                       select {
-                       default:
-                       case msg_chan <- ForwardMsg{
-                               Type: ErrorMsg,
-                               Msg:  err,
-                       }:
-                       }
-                       return
-               } else {
-                       matchfunc = append(matchfunc, cidrx.Contains)
-               }
+func NewConnPool(size int, genConn func() *net.Conn) *ConnPool {
+       return &ConnPool{
+               connMap: new(psync.Map),
+               pool: pool.New(pool.PoolFunc[ConnPoolItem]{
+                       New: func() *ConnPoolItem {
+                               return &ConnPoolItem{p: genConn()}
+                       },
+                       InUse: func(u *ConnPoolItem) bool {
+                               return !u.pooled
+                       },
+                       Reuse: func(u *ConnPoolItem) *ConnPoolItem {
+                               u.pooled = false
+                               return u
+                       },
+                       Pool: func(u *ConnPoolItem) *ConnPoolItem {
+                               u.pooled = true
+                               return u
+                       },
+               }, size),
        }
+}
 
-       //开始准备转发
-       go func(serConn *net.UDPConn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) {
-               defer close(msg_chan)
-               defer serConn.Close()
-
-               type UDPConn struct {
-                       p      *net.UDPConn
-                       pooled bool
-               }
-
-               var (
-                       connMap = make(map[string]*UDPConn)
-                       udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{
-                               New: func() *UDPConn {
-                                       conn, _ := net.ListenUDP(network, nil)
-                                       return &UDPConn{p: conn}
-                               },
-                               InUse: func(u *UDPConn) bool {
-                                       return !u.pooled
-                               },
-                               Reuse: func(u *UDPConn) *UDPConn {
-                                       u.pooled = false
-                                       return u
-                               },
-                               Pool: func(u *UDPConn) *UDPConn {
-                                       u.pooled = true
-                                       return u
-                               },
-                       }, 100)
-               )
-               genConn := func(cliAddr *net.UDPAddr) *UDPConn {
-                       if conn, ok := connMap[cliAddr.String()]; ok {
-                               return conn
-                       } else {
-                               conn = udpPool.Get()
-                               connMap[cliAddr.String()] = conn
-                               return conn
-                       }
-               }
-
-               buf := make([]byte, 20480)
-               for {
-                       n, cliAddr, err := serConn.ReadFromUDP(buf)
-                       if err != nil {
-                               //返回Accept错误
-                               select {
-                               default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: ErrorMsg,
-                                       Msg:  err,
-                               }:
-                               }
-                               continue
-                       }
-
-                       var deny bool
-                       for i := 0; !deny && i < len(matchfunc); i++ {
-                               deny = deny && !matchfunc[i](cliAddr.IP)
-                       }
-                       if deny {
-                               //返回Deny
-                               select {
-                               default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: DenyMsg,
-                                       Msg:  net.Addr(cliAddr),
-                               }:
-                               }
-                               continue
-                       }
-
-                       //返回Accept
-                       select {
-                       default:
-                       case msg_chan <- ForwardMsg{
-                               Type: AcceptMsg,
-                               Msg:  net.Addr(cliAddr),
-                       }:
-                       }
-
-                       select {
-                       case <-closec:
-                       default:
-                               break
-                       }
-
-                       targetConn := genConn(cliAddr)
-                       if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil {
-                               //返回Accept错误
-                               select {
-                               default:
-                               case msg_chan <- ForwardMsg{
-                                       Type: ErrorMsg,
-                                       Msg:  err,
-                               }:
-                               }
-                       } else {
-                               go func() {
-                                       defer udpPool.Put(targetConn)
-
-                                       buf := make([]byte, 20480)
-                                       for {
-                                               targetConn.p.SetDeadline(time.Now().Add(time.Second * 20))
-                                               n, _, e := targetConn.p.ReadFromUDP(buf)
-                                               if e != nil {
-                                                       return
-                                               }
-                                               if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil {
-                                                       return
-                                               }
-                                       }
-                               }()
-                       }
+func (t *ConnPool) Get(id string) (i net.Conn, putBack func()) {
+       if conn, ok := t.connMap.LoadV(id).(*ConnPoolItem); ok {
+               return *conn.p, func() {}
+       } else {
+               conn = t.pool.Get()
+               t.connMap.Store(id, conn)
+               return *conn.p, func() {
+                       t.connMap.Delete(id)
+                       t.pool.Put(conn)
                }
-       }(serConn, targetAddr, msg_chan)
-
-       return
+       }
 }
 
+// func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
+//     //初始化消息通道
+//     msg_chan = make(chan ForwardMsg, 1000)
+
+//     targetAddr, err := net.ResolveUDPAddr(network, targetaddr)
+//     if err != nil {
+//             select {
+//             default:
+//             case msg_chan <- ForwardMsg{
+//                     Type: ErrorMsg,
+//                     Msg:  err,
+//             }:
+//             }
+//             return
+//     }
+
+//     serConn, err := udpListener(network, listenaddr)
+//     if err != nil {
+//             select {
+//             default:
+//             case msg_chan <- ForwardMsg{
+//                     Type: ErrorMsg,
+//                     Msg:  err,
+//             }:
+//             }
+//             return
+//     }
+
+//     closec := make(chan struct{})
+//     //初始化关闭方法
+//     closef = func() {
+//             serConn.Close()
+//             close(closec)
+//     }
+
+//     //返回监听端口
+//     select {
+//     default:
+//     case msg_chan <- ForwardMsg{
+//             Type: LisnMsg,
+//             Msg:  serConn.LocalAddr(),
+//     }:
+//     }
+
+//     matchfunc := []func(ip net.IP) bool{}
+
+//     for _, cidr := range acceptCIDRs {
+//             if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
+//                     select {
+//                     default:
+//                     case msg_chan <- ForwardMsg{
+//                             Type: ErrorMsg,
+//                             Msg:  err,
+//                     }:
+//                     }
+//                     return
+//             } else {
+//                     matchfunc = append(matchfunc, cidrx.Contains)
+//             }
+//     }
+
+//     //开始准备转发
+//     go func(serConn net.Conn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) {
+//             defer close(msg_chan)
+//             defer serConn.Close()
+
+//             type UDPConn struct {
+//                     p      *net.UDPConn
+//                     pooled bool
+//             }
+
+//             var (
+//                     connMap = make(map[string]*UDPConn)
+//                     udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{
+//                             New: func() *UDPConn {
+//                                     conn, _ := net.ListenUDP(network, nil)
+//                                     return &UDPConn{p: conn}
+//                             },
+//                             InUse: func(u *UDPConn) bool {
+//                                     return !u.pooled
+//                             },
+//                             Reuse: func(u *UDPConn) *UDPConn {
+//                                     u.pooled = false
+//                                     return u
+//                             },
+//                             Pool: func(u *UDPConn) *UDPConn {
+//                                     u.pooled = true
+//                                     return u
+//                             },
+//                     }, 100)
+//             )
+//             genConn := func(cliAddr *net.UDPAddr) *UDPConn {
+//                     if conn, ok := connMap[cliAddr.String()]; ok {
+//                             return conn
+//                     } else {
+//                             conn = udpPool.Get()
+//                             connMap[cliAddr.String()] = conn
+//                             return conn
+//                     }
+//             }
+
+//             buf := make([]byte, 20480)
+//             for {
+//                     n, err := serConn.Read(buf)
+//                     if err != nil {
+//                             //返回Accept错误
+//                             select {
+//                             default:
+//                             case msg_chan <- ForwardMsg{
+//                                     Type: ErrorMsg,
+//                                     Msg:  err,
+//                             }:
+//                             }
+//                             continue
+//                     }
+
+//                     ip := net.ParseIP(strings.Split(serConn.RemoteAddr().String(), ":")[0])
+//                     var deny bool
+//                     for i := 0; !deny && i < len(matchfunc); i++ {
+//                             deny = deny && !matchfunc[i](ip)
+//                     }
+//                     if deny {
+//                             //返回Deny
+//                             select {
+//                             default:
+//                             case msg_chan <- ForwardMsg{
+//                                     Type: DenyMsg,
+//                                     Msg:  serConn.RemoteAddr(),
+//                             }:
+//                             }
+//                             serConn.Close()
+//                             continue
+//                     }
+
+//                     //返回Accept
+//                     select {
+//                     default:
+//                     case msg_chan <- ForwardMsg{
+//                             Type: AcceptMsg,
+//                             Msg:  serConn.RemoteAddr(),
+//                     }:
+//                     }
+
+//                     select {
+//                     case <-closec:
+//                     default:
+//                             break
+//                     }
+
+//                     targetConn := genConn(cliAddr)
+//                     if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil {
+//                             //返回Accept错误
+//                             select {
+//                             default:
+//                             case msg_chan <- ForwardMsg{
+//                                     Type: ErrorMsg,
+//                                     Msg:  err,
+//                             }:
+//                             }
+//                     } else {
+//                             go func() {
+//                                     defer udpPool.Put(targetConn)
+
+//                                     buf := make([]byte, 20480)
+//                                     for {
+//                                             targetConn.p.SetDeadline(time.Now().Add(time.Second * 20))
+//                                             n, _, e := targetConn.p.ReadFromUDP(buf)
+//                                             if e != nil {
+//                                                     return
+//                                             }
+//                                             if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil {
+//                                                     return
+//                                             }
+//                                     }
+//                             }()
+//                     }
+//             }
+//     }(serConn, targetAddr, msg_chan)
+
+//     return
+// }
+
 func (this *netl) GetLocalDns() error {
        if runtime.GOOS == "windows" {
                cmd := exec.Command("nslookup", "127.0.0.1")