From: qydysky Date: Sun, 14 Apr 2024 03:30:42 +0000 (+0000) Subject: 1 X-Git-Tag: v0.28.20240414033603 X-Git-Url: http://127.0.0.1:8081/?a=commitdiff_plain;h=f2258a47c49b1cacdf683f24e812d176d611d996;p=part%2F.git 1 --- diff --git a/Net.go b/Net.go index d586e5a..5446d4a 100644 --- a/Net.go +++ b/Net.go @@ -3,6 +3,7 @@ package part import ( "bytes" "errors" + "io" "net" "net/url" "os/exec" @@ -13,6 +14,13 @@ import ( "github.com/miekg/dns" pool "github.com/qydysky/part/pool" + psync "github.com/qydysky/part/sync" +) + +var ( + ErrDnsNoAnswer = errors.New("ErrDnsNoAnswer") + ErrNetworkNoSupport = errors.New("ErrNetworkNoSupport") + ErrUdpOverflow = errors.New("ErrUdpOverflow") // default:1500 set higher pkgSize at NewUdpListener ) type netl struct { @@ -45,7 +53,7 @@ func (this *netl) Nslookup(target string) error { return err } if len(r.Answer) == 0 { - return errors.New("no answer") + return ErrDnsNoAnswer } this.RV = append(this.RV, dns.Field(r.Answer[0], 1)) @@ -86,8 +94,6 @@ func connBridge(a, b net.Conn) { wg.Add(2) go func() { defer func() { - a.Close() - b.Close() fin <- true wg.Done() }() @@ -99,20 +105,17 @@ func connBridge(a, b net.Conn) { case <-fin: return default: - n, err := a.Read(buf) - - if err != nil { + if n, err := a.Read(buf); err != nil { + return + } else if _, err = b.Write(buf[:n]); err != nil { return } - b.Write(buf[:n]) } } }() go func() { defer func() { - a.Close() - b.Close() fin <- true wg.Done() }() @@ -124,12 +127,11 @@ func connBridge(a, b net.Conn) { case <-fin: return default: - n, err := b.Read(buf) - - if err != nil { + if n, err := b.Read(buf); err != nil { + return + } else if _, err = a.Write(buf[:n]); err != nil { return } - a.Write(buf[:n]) } } }() @@ -137,22 +139,54 @@ func connBridge(a, b net.Conn) { wg.Wait() } -// "tcp", "tcp4", "tcp6", "unix" or "unixpacket" -func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) { - //初始化消息通道 +func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) { msg_chan = make(chan ForwardMsg, 1000) + lisNet := strings.Split(listenaddr, "://")[0] + lisAddr := strings.Split(listenaddr, "://")[1] + tarNet := strings.Split(targetaddr, "://")[0] + tarAddr := strings.Split(targetaddr, "://")[1] + //尝试监听 - listener, err := net.Listen(network, listenaddr) - if err != nil { - select { + var listener net.Listener + { + var err error + switch lisNet { + case "tcp", "tcp4", "tcp6", "unix", "unixpacket": + listener, err = net.Listen(lisNet, lisAddr) + case "udp", "udp4", "udp6": + listener, err = NewUdpListener(lisNet, lisAddr) default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: + err = ErrNetworkNoSupport + } + if err != nil { + select { + default: + case msg_chan <- ForwardMsg{ + Type: ErrorMsg, + Msg: err, + }: + } + return + } + } + { + var err error + switch lisNet { + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip", "ip4", "ip6", "unix", "unixgram", "unixpacket": + default: + err = ErrNetworkNoSupport + } + if err != nil { + select { + default: + case msg_chan <- ForwardMsg{ + Type: ErrorMsg, + Msg: err, + }: + } + return } - return } closec := make(chan struct{}) @@ -172,7 +206,6 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos } matchfunc := []func(ip net.IP) bool{} - for _, cidr := range acceptCIDRs { if _, cidrx, err := net.ParseCIDR(cidr); err != nil { select { @@ -189,10 +222,14 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos } //开始准备转发 - go func(listener net.Listener, targetaddr, network string, msg_chan chan ForwardMsg) { + go func(listener net.Listener, msg_chan chan ForwardMsg) { defer close(msg_chan) defer listener.Close() + tarConnPool := NewConnPool(100, func() *net.Conn { + conn, _ := net.Dial(tarNet, tarAddr) + return &conn + }) for { proxyconn, err := listener.Accept() if err != nil { @@ -211,7 +248,7 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos var deny bool for i := 0; !deny && i < len(matchfunc); i++ { - deny = deny && !matchfunc[i](ip) + deny = deny || !matchfunc[i](ip) } if deny { //返回Deny @@ -241,216 +278,319 @@ func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (clos break } - targetconn, err := net.Dial(network, targetaddr) - if err != nil { - select { + go func() { + targetconn, putBackf := tarConnPool.Get(tarAddr) + defer putBackf() + + connBridge(proxyconn, targetconn) + + switch lisNet { + case "tcp", "tcp4", "tcp6", "unix", "unixpacket": + case "udp", "udp4", "udp6": + time.Sleep(time.Second * 10) default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: } proxyconn.Close() - continue - } - - go connBridge(proxyconn, targetconn) + targetconn.Close() + }() } - }(listener, targetaddr, network, msg_chan) + }(listener, msg_chan) return } -func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) { - //初始化消息通道 - msg_chan = make(chan ForwardMsg, 1000) +type udpConn struct { + conn *net.UDPConn + remoteAdd *net.UDPAddr + reader io.Reader +} - lisAddr := func(network, listenaddr string) (*net.UDPConn, error) { - if udpAddr, err := net.ResolveUDPAddr(network, listenaddr); err != nil { - return nil, err - } else if conn, err := net.ListenUDP(network, udpAddr); err != nil { - return nil, err - } else { - return conn, nil - } +func NewUdpConn(per []byte, remoteAdd *net.UDPAddr, conn *net.UDPConn) *udpConn { + return &udpConn{ + conn: conn, + remoteAdd: remoteAdd, + reader: io.MultiReader(bytes.NewReader(per), conn), } +} +func (t udpConn) Read(b []byte) (n int, err error) { + return t.reader.Read(b) +} +func (t udpConn) Write(b []byte) (n int, err error) { + return t.conn.WriteToUDP(b, t.remoteAdd) +} +func (t udpConn) Close() error { return nil } +func (t udpConn) LocalAddr() net.Addr { return t.conn.LocalAddr() } +func (t udpConn) RemoteAddr() net.Addr { return t.remoteAdd } +func (t udpConn) SetDeadline(b time.Time) error { return t.conn.SetDeadline(b) } +func (t udpConn) SetReadDeadline(b time.Time) error { return t.conn.SetReadDeadline(b) } +func (t udpConn) SetWriteDeadline(b time.Time) error { return t.conn.SetWriteDeadline(b) } + +type udpLis struct { + addr net.Addr + conn *net.UDPConn + pkgSize int +} - targetAddr, err := net.ResolveUDPAddr(network, targetaddr) +func NewUdpListener(network, listenaddr string, pkgSize ...int) (*udpLis, error) { + udpAddr, err := net.ResolveUDPAddr(network, listenaddr) if err != nil { - select { - default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: - } - return + return nil, err } - - serConn, err := lisAddr(network, listenaddr) - if err != nil { - select { - default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: - } - return + pkgSize = append(pkgSize, 1500) + if conn, err := net.ListenUDP(network, udpAddr); err != nil { + return nil, err + } else { + return &udpLis{ + addr: udpAddr, + conn: conn, + pkgSize: pkgSize[0], + }, nil } - - closec := make(chan struct{}) - //初始化关闭方法 - closef = func() { - serConn.Close() - close(closec) +} +func (t udpLis) Accept() (net.Conn, error) { + buf := make([]byte, t.pkgSize) + if n, remoteAdd, e := t.conn.ReadFromUDP(buf); e != nil { + return nil, e + } else if n == t.pkgSize { + return nil, ErrUdpOverflow + } else { + return NewUdpConn(buf[:n], remoteAdd, t.conn), nil } +} +func (t udpLis) Close() error { + return t.conn.Close() +} +func (t udpLis) Addr() net.Addr { + return t.addr +} - //返回监听端口 - select { - default: - case msg_chan <- ForwardMsg{ - Type: LisnMsg, - Msg: serConn.LocalAddr(), - }: - } +type ConnPool struct { + connMap *psync.Map + pool *pool.Buf[ConnPoolItem] +} - matchfunc := []func(ip net.IP) bool{} +type ConnPoolItem struct { + p *net.Conn + pooled bool +} - for _, cidr := range acceptCIDRs { - if _, cidrx, err := net.ParseCIDR(cidr); err != nil { - select { - default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: - } - return - } else { - matchfunc = append(matchfunc, cidrx.Contains) - } +func NewConnPool(size int, genConn func() *net.Conn) *ConnPool { + return &ConnPool{ + connMap: new(psync.Map), + pool: pool.New(pool.PoolFunc[ConnPoolItem]{ + New: func() *ConnPoolItem { + return &ConnPoolItem{p: genConn()} + }, + InUse: func(u *ConnPoolItem) bool { + return !u.pooled + }, + Reuse: func(u *ConnPoolItem) *ConnPoolItem { + u.pooled = false + return u + }, + Pool: func(u *ConnPoolItem) *ConnPoolItem { + u.pooled = true + return u + }, + }, size), } +} - //开始准备转发 - go func(serConn *net.UDPConn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) { - defer close(msg_chan) - defer serConn.Close() - - type UDPConn struct { - p *net.UDPConn - pooled bool - } - - var ( - connMap = make(map[string]*UDPConn) - udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{ - New: func() *UDPConn { - conn, _ := net.ListenUDP(network, nil) - return &UDPConn{p: conn} - }, - InUse: func(u *UDPConn) bool { - return !u.pooled - }, - Reuse: func(u *UDPConn) *UDPConn { - u.pooled = false - return u - }, - Pool: func(u *UDPConn) *UDPConn { - u.pooled = true - return u - }, - }, 100) - ) - genConn := func(cliAddr *net.UDPAddr) *UDPConn { - if conn, ok := connMap[cliAddr.String()]; ok { - return conn - } else { - conn = udpPool.Get() - connMap[cliAddr.String()] = conn - return conn - } - } - - buf := make([]byte, 20480) - for { - n, cliAddr, err := serConn.ReadFromUDP(buf) - if err != nil { - //返回Accept错误 - select { - default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: - } - continue - } - - var deny bool - for i := 0; !deny && i < len(matchfunc); i++ { - deny = deny && !matchfunc[i](cliAddr.IP) - } - if deny { - //返回Deny - select { - default: - case msg_chan <- ForwardMsg{ - Type: DenyMsg, - Msg: net.Addr(cliAddr), - }: - } - continue - } - - //返回Accept - select { - default: - case msg_chan <- ForwardMsg{ - Type: AcceptMsg, - Msg: net.Addr(cliAddr), - }: - } - - select { - case <-closec: - default: - break - } - - targetConn := genConn(cliAddr) - if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil { - //返回Accept错误 - select { - default: - case msg_chan <- ForwardMsg{ - Type: ErrorMsg, - Msg: err, - }: - } - } else { - go func() { - defer udpPool.Put(targetConn) - - buf := make([]byte, 20480) - for { - targetConn.p.SetDeadline(time.Now().Add(time.Second * 20)) - n, _, e := targetConn.p.ReadFromUDP(buf) - if e != nil { - return - } - if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil { - return - } - } - }() - } +func (t *ConnPool) Get(id string) (i net.Conn, putBack func()) { + if conn, ok := t.connMap.LoadV(id).(*ConnPoolItem); ok { + return *conn.p, func() {} + } else { + conn = t.pool.Get() + t.connMap.Store(id, conn) + return *conn.p, func() { + t.connMap.Delete(id) + t.pool.Put(conn) } - }(serConn, targetAddr, msg_chan) - - return + } } +// func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) { +// //初始化消息通道 +// msg_chan = make(chan ForwardMsg, 1000) + +// targetAddr, err := net.ResolveUDPAddr(network, targetaddr) +// if err != nil { +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: ErrorMsg, +// Msg: err, +// }: +// } +// return +// } + +// serConn, err := udpListener(network, listenaddr) +// if err != nil { +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: ErrorMsg, +// Msg: err, +// }: +// } +// return +// } + +// closec := make(chan struct{}) +// //初始化关闭方法 +// closef = func() { +// serConn.Close() +// close(closec) +// } + +// //返回监听端口 +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: LisnMsg, +// Msg: serConn.LocalAddr(), +// }: +// } + +// matchfunc := []func(ip net.IP) bool{} + +// for _, cidr := range acceptCIDRs { +// if _, cidrx, err := net.ParseCIDR(cidr); err != nil { +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: ErrorMsg, +// Msg: err, +// }: +// } +// return +// } else { +// matchfunc = append(matchfunc, cidrx.Contains) +// } +// } + +// //开始准备转发 +// go func(serConn net.Conn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) { +// defer close(msg_chan) +// defer serConn.Close() + +// type UDPConn struct { +// p *net.UDPConn +// pooled bool +// } + +// var ( +// connMap = make(map[string]*UDPConn) +// udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{ +// New: func() *UDPConn { +// conn, _ := net.ListenUDP(network, nil) +// return &UDPConn{p: conn} +// }, +// InUse: func(u *UDPConn) bool { +// return !u.pooled +// }, +// Reuse: func(u *UDPConn) *UDPConn { +// u.pooled = false +// return u +// }, +// Pool: func(u *UDPConn) *UDPConn { +// u.pooled = true +// return u +// }, +// }, 100) +// ) +// genConn := func(cliAddr *net.UDPAddr) *UDPConn { +// if conn, ok := connMap[cliAddr.String()]; ok { +// return conn +// } else { +// conn = udpPool.Get() +// connMap[cliAddr.String()] = conn +// return conn +// } +// } + +// buf := make([]byte, 20480) +// for { +// n, err := serConn.Read(buf) +// if err != nil { +// //返回Accept错误 +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: ErrorMsg, +// Msg: err, +// }: +// } +// continue +// } + +// ip := net.ParseIP(strings.Split(serConn.RemoteAddr().String(), ":")[0]) +// var deny bool +// for i := 0; !deny && i < len(matchfunc); i++ { +// deny = deny && !matchfunc[i](ip) +// } +// if deny { +// //返回Deny +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: DenyMsg, +// Msg: serConn.RemoteAddr(), +// }: +// } +// serConn.Close() +// continue +// } + +// //返回Accept +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: AcceptMsg, +// Msg: serConn.RemoteAddr(), +// }: +// } + +// select { +// case <-closec: +// default: +// break +// } + +// targetConn := genConn(cliAddr) +// if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil { +// //返回Accept错误 +// select { +// default: +// case msg_chan <- ForwardMsg{ +// Type: ErrorMsg, +// Msg: err, +// }: +// } +// } else { +// go func() { +// defer udpPool.Put(targetConn) + +// buf := make([]byte, 20480) +// for { +// targetConn.p.SetDeadline(time.Now().Add(time.Second * 20)) +// n, _, e := targetConn.p.ReadFromUDP(buf) +// if e != nil { +// return +// } +// if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil { +// return +// } +// } +// }() +// } +// } +// }(serConn, targetAddr, msg_chan) + +// return +// } + func (this *netl) GetLocalDns() error { if runtime.GOOS == "windows" { cmd := exec.Command("nslookup", "127.0.0.1")