]> 127.0.0.1 Git - part/.git/commitdiff
1 v0.28.20240411191949
authorqydysky <qydysky@foxmail.com>
Thu, 11 Apr 2024 19:13:54 +0000 (19:13 +0000)
committerqydysky <qydysky@foxmail.com>
Thu, 11 Apr 2024 19:13:54 +0000 (19:13 +0000)
Net.go

diff --git a/Net.go b/Net.go
index 9b41d440db7fb38665cc60673f51469743ab27e7..87a84675d7a0e73fa62730071524765b1d86e079 100644 (file)
--- a/Net.go
+++ b/Net.go
@@ -12,6 +12,7 @@ import (
        "time"
 
        "github.com/miekg/dns"
+       pool "github.com/qydysky/part/pool"
 )
 
 type netl struct {
@@ -77,6 +78,66 @@ type ForwardMsg struct {
        Msg  interface{}
 }
 
+// 桥
+func connBridge(a, b net.Conn) {
+       fin := make(chan bool, 1)
+       var wg sync.WaitGroup
+
+       wg.Add(2)
+       go func() {
+               defer func() {
+                       a.Close()
+                       b.Close()
+                       fin <- true
+                       wg.Done()
+               }()
+
+               buf := make([]byte, 20480)
+
+               for {
+                       select {
+                       case <-fin:
+                               return
+                       default:
+                               n, err := a.Read(buf)
+
+                               if err != nil {
+                                       return
+                               }
+                               b.Write(buf[:n])
+                       }
+               }
+       }()
+
+       go func() {
+               defer func() {
+                       a.Close()
+                       b.Close()
+                       fin <- true
+                       wg.Done()
+               }()
+
+               buf := make([]byte, 20480)
+
+               for {
+                       select {
+                       case <-fin:
+                               return
+                       default:
+                               n, err := b.Read(buf)
+
+                               if err != nil {
+                                       return
+                               }
+                               a.Write(buf[:n])
+                       }
+               }
+       }()
+
+       wg.Wait()
+}
+
+// "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
 func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
        //初始化消息通道
        msg_chan = make(chan ForwardMsg, 1000)
@@ -132,65 +193,6 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos
                defer close(msg_chan)
                defer listener.Close()
 
-               //tcp 桥
-               tcpBridge2 := func(a, b net.Conn) {
-                       fin := make(chan bool, 1)
-                       var wg sync.WaitGroup
-
-                       wg.Add(2)
-                       go func() {
-                               defer func() {
-                                       a.Close()
-                                       b.Close()
-                                       fin <- true
-                                       wg.Done()
-                               }()
-
-                               buf := make([]byte, 20480)
-
-                               for {
-                                       select {
-                                       case <-fin:
-                                               return
-                                       default:
-                                               n, err := a.Read(buf)
-
-                                               if err != nil {
-                                                       return
-                                               }
-                                               b.Write(buf[:n])
-                                       }
-                               }
-                       }()
-
-                       go func() {
-                               defer func() {
-                                       a.Close()
-                                       b.Close()
-                                       fin <- true
-                                       wg.Done()
-                               }()
-
-                               buf := make([]byte, 20480)
-
-                               for {
-                                       select {
-                                       case <-fin:
-                                               return
-                                       default:
-                                               n, err := b.Read(buf)
-
-                                               if err != nil {
-                                                       return
-                                               }
-                                               a.Write(buf[:n])
-                                       }
-                               }
-                       }()
-
-                       wg.Wait()
-               }
-
                for {
                        proxyconn, err := listener.Accept()
                        if err != nil {
@@ -252,13 +254,203 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos
                                continue
                        }
 
-                       go tcpBridge2(proxyconn, targetconn)
+                       go connBridge(proxyconn, targetconn)
                }
        }(listener, targetaddr, network, msg_chan)
 
        return
 }
 
+func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
+       //初始化消息通道
+       msg_chan = make(chan ForwardMsg, 1000)
+
+       lisAddr := func(network, listenaddr string) (*net.UDPConn, error) {
+               if udpAddr, err := net.ResolveUDPAddr(network, listenaddr); err != nil {
+                       return nil, err
+               } else if conn, err := net.ListenUDP(network, udpAddr); err != nil {
+                       return nil, err
+               } else {
+                       return conn, nil
+               }
+       }
+
+       targetAddr, err := net.ResolveUDPAddr(network, targetaddr)
+       if err != nil {
+               select {
+               default:
+               case msg_chan <- ForwardMsg{
+                       Type: ErrorMsg,
+                       Msg:  err,
+               }:
+               }
+               return
+       }
+
+       serConn, err := lisAddr(network, listenaddr)
+       if err != nil {
+               select {
+               default:
+               case msg_chan <- ForwardMsg{
+                       Type: ErrorMsg,
+                       Msg:  err,
+               }:
+               }
+               return
+       }
+
+       closec := make(chan struct{})
+       //初始化关闭方法
+       closef = func() {
+               serConn.Close()
+               close(closec)
+       }
+
+       //返回监听端口
+       select {
+       default:
+       case msg_chan <- ForwardMsg{
+               Type: PortMsg,
+               Msg:  serConn.LocalAddr().(*net.UDPAddr).Port,
+       }:
+       }
+
+       matchfunc := []func(ip net.IP) bool{}
+
+       for _, cidr := range acceptCIDRs {
+               if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
+                       select {
+                       default:
+                       case msg_chan <- ForwardMsg{
+                               Type: ErrorMsg,
+                               Msg:  err,
+                       }:
+                       }
+                       return
+               } else {
+                       matchfunc = append(matchfunc, cidrx.Contains)
+               }
+       }
+
+       //开始准备转发
+       go func(serConn *net.UDPConn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) {
+               defer close(msg_chan)
+               defer serConn.Close()
+
+               type UDPConn struct {
+                       p      *net.UDPConn
+                       pooled bool
+               }
+
+               var (
+                       connMap = make(map[string]*UDPConn)
+                       udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{
+                               New: func() *UDPConn {
+                                       conn, _ := net.ListenUDP(network, nil)
+                                       return &UDPConn{p: conn}
+                               },
+                               InUse: func(u *UDPConn) bool {
+                                       return !u.pooled
+                               },
+                               Reuse: func(u *UDPConn) *UDPConn {
+                                       u.pooled = false
+                                       return u
+                               },
+                               Pool: func(u *UDPConn) *UDPConn {
+                                       u.pooled = true
+                                       return u
+                               },
+                       }, 100)
+               )
+               genConn := func(cliAddr *net.UDPAddr) *UDPConn {
+                       if conn, ok := connMap[cliAddr.String()]; ok {
+                               return conn
+                       } else {
+                               conn = udpPool.Get()
+                               connMap[cliAddr.String()] = conn
+                               return conn
+                       }
+               }
+
+               buf := make([]byte, 20480)
+               for {
+                       n, cliAddr, err := serConn.ReadFromUDP(buf)
+                       if err != nil {
+                               //返回Accept错误
+                               select {
+                               default:
+                               case msg_chan <- ForwardMsg{
+                                       Type: ErrorMsg,
+                                       Msg:  err,
+                               }:
+                               }
+                               continue
+                       }
+
+                       var accpet bool
+                       for i := 0; i < len(matchfunc); i++ {
+                               accpet = accpet || matchfunc[i](cliAddr.IP)
+                       }
+                       if !accpet {
+                               //返回Deny
+                               select {
+                               default:
+                               case msg_chan <- ForwardMsg{
+                                       Type: DenyMsg,
+                                       Msg:  net.Addr(cliAddr),
+                               }:
+                               }
+                               continue
+                       }
+
+                       //返回Accept
+                       select {
+                       default:
+                       case msg_chan <- ForwardMsg{
+                               Type: AcceptMsg,
+                               Msg:  net.Addr(cliAddr),
+                       }:
+                       }
+
+                       select {
+                       case <-closec:
+                       default:
+                               break
+                       }
+
+                       targetConn := genConn(cliAddr)
+                       if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil {
+                               //返回Accept错误
+                               select {
+                               default:
+                               case msg_chan <- ForwardMsg{
+                                       Type: ErrorMsg,
+                                       Msg:  err,
+                               }:
+                               }
+                       } else {
+                               go func() {
+                                       defer udpPool.Put(targetConn)
+
+                                       buf := make([]byte, 20480)
+                                       for {
+                                               targetConn.p.SetDeadline(time.Now().Add(time.Second * 20))
+                                               n, _, e := targetConn.p.ReadFromUDP(buf)
+                                               if e != nil {
+                                                       return
+                                               }
+                                               if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil {
+                                                       return
+                                               }
+                                       }
+                               }()
+                       }
+               }
+       }(serConn, targetAddr, msg_chan)
+
+       return
+}
+
 func (this *netl) GetLocalDns() error {
        if runtime.GOOS == "windows" {
                cmd := exec.Command("nslookup", "127.0.0.1")