import (
"bytes"
"errors"
+ "io"
"net"
"net/url"
"os/exec"
"github.com/miekg/dns"
pool "github.com/qydysky/part/pool"
+ psync "github.com/qydysky/part/sync"
+)
+
+var (
+ ErrDnsNoAnswer = errors.New("ErrDnsNoAnswer")
+ ErrNetworkNoSupport = errors.New("ErrNetworkNoSupport")
+ ErrUdpOverflow = errors.New("ErrUdpOverflow") // default:1500 set higher pkgSize at NewUdpListener
)
type netl struct {
return err
}
if len(r.Answer) == 0 {
- return errors.New("no answer")
+ return ErrDnsNoAnswer
}
this.RV = append(this.RV, dns.Field(r.Answer[0], 1))
wg.Add(2)
go func() {
defer func() {
- a.Close()
- b.Close()
fin <- true
wg.Done()
}()
case <-fin:
return
default:
- n, err := a.Read(buf)
-
- if err != nil {
+ if n, err := a.Read(buf); err != nil {
+ return
+ } else if _, err = b.Write(buf[:n]); err != nil {
return
}
- b.Write(buf[:n])
}
}
}()
go func() {
defer func() {
- a.Close()
- b.Close()
fin <- true
wg.Done()
}()
case <-fin:
return
default:
- n, err := b.Read(buf)
-
- if err != nil {
+ if n, err := b.Read(buf); err != nil {
+ return
+ } else if _, err = a.Write(buf[:n]); err != nil {
return
}
- a.Write(buf[:n])
}
}
}()
wg.Wait()
}
-// "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
-func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
- //初始化消息通道
+func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
msg_chan = make(chan ForwardMsg, 1000)
+ lisNet := strings.Split(listenaddr, "://")[0]
+ lisAddr := strings.Split(listenaddr, "://")[1]
+ tarNet := strings.Split(targetaddr, "://")[0]
+ tarAddr := strings.Split(targetaddr, "://")[1]
+
//尝试监听
- listener, err := net.Listen(network, listenaddr)
- if err != nil {
- select {
+ var listener net.Listener
+ {
+ var err error
+ switch lisNet {
+ case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
+ listener, err = net.Listen(lisNet, lisAddr)
+ case "udp", "udp4", "udp6":
+ listener, err = NewUdpListener(lisNet, lisAddr)
default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
+ err = ErrNetworkNoSupport
+ }
+ if err != nil {
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ return
+ }
+ }
+ {
+ var err error
+ switch lisNet {
+ case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip", "ip4", "ip6", "unix", "unixgram", "unixpacket":
+ default:
+ err = ErrNetworkNoSupport
+ }
+ if err != nil {
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ return
}
- return
}
closec := make(chan struct{})
}
matchfunc := []func(ip net.IP) bool{}
-
for _, cidr := range acceptCIDRs {
if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
select {
}
//开始准备转发
- go func(listener net.Listener, targetaddr, network string, msg_chan chan ForwardMsg) {
+ go func(listener net.Listener, msg_chan chan ForwardMsg) {
defer close(msg_chan)
defer listener.Close()
+ tarConnPool := NewConnPool(100, func() *net.Conn {
+ conn, _ := net.Dial(tarNet, tarAddr)
+ return &conn
+ })
for {
proxyconn, err := listener.Accept()
if err != nil {
var deny bool
for i := 0; !deny && i < len(matchfunc); i++ {
- deny = deny && !matchfunc[i](ip)
+ deny = deny || !matchfunc[i](ip)
}
if deny {
//返回Deny
break
}
- targetconn, err := net.Dial(network, targetaddr)
- if err != nil {
- select {
+ go func() {
+ targetconn, putBackf := tarConnPool.Get(tarAddr)
+ defer putBackf()
+
+ connBridge(proxyconn, targetconn)
+
+ switch lisNet {
+ case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
+ case "udp", "udp4", "udp6":
+ time.Sleep(time.Second * 10)
default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
}
proxyconn.Close()
- continue
- }
-
- go connBridge(proxyconn, targetconn)
+ targetconn.Close()
+ }()
}
- }(listener, targetaddr, network, msg_chan)
+ }(listener, msg_chan)
return
}
-func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
- //初始化消息通道
- msg_chan = make(chan ForwardMsg, 1000)
+type udpConn struct {
+ conn *net.UDPConn
+ remoteAdd *net.UDPAddr
+ reader io.Reader
+}
- lisAddr := func(network, listenaddr string) (*net.UDPConn, error) {
- if udpAddr, err := net.ResolveUDPAddr(network, listenaddr); err != nil {
- return nil, err
- } else if conn, err := net.ListenUDP(network, udpAddr); err != nil {
- return nil, err
- } else {
- return conn, nil
- }
+func NewUdpConn(per []byte, remoteAdd *net.UDPAddr, conn *net.UDPConn) *udpConn {
+ return &udpConn{
+ conn: conn,
+ remoteAdd: remoteAdd,
+ reader: io.MultiReader(bytes.NewReader(per), conn),
}
+}
+func (t udpConn) Read(b []byte) (n int, err error) {
+ return t.reader.Read(b)
+}
+func (t udpConn) Write(b []byte) (n int, err error) {
+ return t.conn.WriteToUDP(b, t.remoteAdd)
+}
+func (t udpConn) Close() error { return nil }
+func (t udpConn) LocalAddr() net.Addr { return t.conn.LocalAddr() }
+func (t udpConn) RemoteAddr() net.Addr { return t.remoteAdd }
+func (t udpConn) SetDeadline(b time.Time) error { return t.conn.SetDeadline(b) }
+func (t udpConn) SetReadDeadline(b time.Time) error { return t.conn.SetReadDeadline(b) }
+func (t udpConn) SetWriteDeadline(b time.Time) error { return t.conn.SetWriteDeadline(b) }
+
+type udpLis struct {
+ addr net.Addr
+ conn *net.UDPConn
+ pkgSize int
+}
- targetAddr, err := net.ResolveUDPAddr(network, targetaddr)
+func NewUdpListener(network, listenaddr string, pkgSize ...int) (*udpLis, error) {
+ udpAddr, err := net.ResolveUDPAddr(network, listenaddr)
if err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
- return
+ return nil, err
}
-
- serConn, err := lisAddr(network, listenaddr)
- if err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
- return
+ pkgSize = append(pkgSize, 1500)
+ if conn, err := net.ListenUDP(network, udpAddr); err != nil {
+ return nil, err
+ } else {
+ return &udpLis{
+ addr: udpAddr,
+ conn: conn,
+ pkgSize: pkgSize[0],
+ }, nil
}
-
- closec := make(chan struct{})
- //初始化关闭方法
- closef = func() {
- serConn.Close()
- close(closec)
+}
+func (t udpLis) Accept() (net.Conn, error) {
+ buf := make([]byte, t.pkgSize)
+ if n, remoteAdd, e := t.conn.ReadFromUDP(buf); e != nil {
+ return nil, e
+ } else if n == t.pkgSize {
+ return nil, ErrUdpOverflow
+ } else {
+ return NewUdpConn(buf[:n], remoteAdd, t.conn), nil
}
+}
+func (t udpLis) Close() error {
+ return t.conn.Close()
+}
+func (t udpLis) Addr() net.Addr {
+ return t.addr
+}
- //返回监听端口
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: LisnMsg,
- Msg: serConn.LocalAddr(),
- }:
- }
+type ConnPool struct {
+ connMap *psync.Map
+ pool *pool.Buf[ConnPoolItem]
+}
- matchfunc := []func(ip net.IP) bool{}
+type ConnPoolItem struct {
+ p *net.Conn
+ pooled bool
+}
- for _, cidr := range acceptCIDRs {
- if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
- return
- } else {
- matchfunc = append(matchfunc, cidrx.Contains)
- }
+func NewConnPool(size int, genConn func() *net.Conn) *ConnPool {
+ return &ConnPool{
+ connMap: new(psync.Map),
+ pool: pool.New(pool.PoolFunc[ConnPoolItem]{
+ New: func() *ConnPoolItem {
+ return &ConnPoolItem{p: genConn()}
+ },
+ InUse: func(u *ConnPoolItem) bool {
+ return !u.pooled
+ },
+ Reuse: func(u *ConnPoolItem) *ConnPoolItem {
+ u.pooled = false
+ return u
+ },
+ Pool: func(u *ConnPoolItem) *ConnPoolItem {
+ u.pooled = true
+ return u
+ },
+ }, size),
}
+}
- //开始准备转发
- go func(serConn *net.UDPConn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) {
- defer close(msg_chan)
- defer serConn.Close()
-
- type UDPConn struct {
- p *net.UDPConn
- pooled bool
- }
-
- var (
- connMap = make(map[string]*UDPConn)
- udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{
- New: func() *UDPConn {
- conn, _ := net.ListenUDP(network, nil)
- return &UDPConn{p: conn}
- },
- InUse: func(u *UDPConn) bool {
- return !u.pooled
- },
- Reuse: func(u *UDPConn) *UDPConn {
- u.pooled = false
- return u
- },
- Pool: func(u *UDPConn) *UDPConn {
- u.pooled = true
- return u
- },
- }, 100)
- )
- genConn := func(cliAddr *net.UDPAddr) *UDPConn {
- if conn, ok := connMap[cliAddr.String()]; ok {
- return conn
- } else {
- conn = udpPool.Get()
- connMap[cliAddr.String()] = conn
- return conn
- }
- }
-
- buf := make([]byte, 20480)
- for {
- n, cliAddr, err := serConn.ReadFromUDP(buf)
- if err != nil {
- //返回Accept错误
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
- continue
- }
-
- var deny bool
- for i := 0; !deny && i < len(matchfunc); i++ {
- deny = deny && !matchfunc[i](cliAddr.IP)
- }
- if deny {
- //返回Deny
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: DenyMsg,
- Msg: net.Addr(cliAddr),
- }:
- }
- continue
- }
-
- //返回Accept
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: AcceptMsg,
- Msg: net.Addr(cliAddr),
- }:
- }
-
- select {
- case <-closec:
- default:
- break
- }
-
- targetConn := genConn(cliAddr)
- if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil {
- //返回Accept错误
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
- } else {
- go func() {
- defer udpPool.Put(targetConn)
-
- buf := make([]byte, 20480)
- for {
- targetConn.p.SetDeadline(time.Now().Add(time.Second * 20))
- n, _, e := targetConn.p.ReadFromUDP(buf)
- if e != nil {
- return
- }
- if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil {
- return
- }
- }
- }()
- }
+func (t *ConnPool) Get(id string) (i net.Conn, putBack func()) {
+ if conn, ok := t.connMap.LoadV(id).(*ConnPoolItem); ok {
+ return *conn.p, func() {}
+ } else {
+ conn = t.pool.Get()
+ t.connMap.Store(id, conn)
+ return *conn.p, func() {
+ t.connMap.Delete(id)
+ t.pool.Put(conn)
}
- }(serConn, targetAddr, msg_chan)
-
- return
+ }
}
+// func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
+// //初始化消息通道
+// msg_chan = make(chan ForwardMsg, 1000)
+
+// targetAddr, err := net.ResolveUDPAddr(network, targetaddr)
+// if err != nil {
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: ErrorMsg,
+// Msg: err,
+// }:
+// }
+// return
+// }
+
+// serConn, err := udpListener(network, listenaddr)
+// if err != nil {
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: ErrorMsg,
+// Msg: err,
+// }:
+// }
+// return
+// }
+
+// closec := make(chan struct{})
+// //初始化关闭方法
+// closef = func() {
+// serConn.Close()
+// close(closec)
+// }
+
+// //返回监听端口
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: LisnMsg,
+// Msg: serConn.LocalAddr(),
+// }:
+// }
+
+// matchfunc := []func(ip net.IP) bool{}
+
+// for _, cidr := range acceptCIDRs {
+// if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: ErrorMsg,
+// Msg: err,
+// }:
+// }
+// return
+// } else {
+// matchfunc = append(matchfunc, cidrx.Contains)
+// }
+// }
+
+// //开始准备转发
+// go func(serConn net.Conn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) {
+// defer close(msg_chan)
+// defer serConn.Close()
+
+// type UDPConn struct {
+// p *net.UDPConn
+// pooled bool
+// }
+
+// var (
+// connMap = make(map[string]*UDPConn)
+// udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{
+// New: func() *UDPConn {
+// conn, _ := net.ListenUDP(network, nil)
+// return &UDPConn{p: conn}
+// },
+// InUse: func(u *UDPConn) bool {
+// return !u.pooled
+// },
+// Reuse: func(u *UDPConn) *UDPConn {
+// u.pooled = false
+// return u
+// },
+// Pool: func(u *UDPConn) *UDPConn {
+// u.pooled = true
+// return u
+// },
+// }, 100)
+// )
+// genConn := func(cliAddr *net.UDPAddr) *UDPConn {
+// if conn, ok := connMap[cliAddr.String()]; ok {
+// return conn
+// } else {
+// conn = udpPool.Get()
+// connMap[cliAddr.String()] = conn
+// return conn
+// }
+// }
+
+// buf := make([]byte, 20480)
+// for {
+// n, err := serConn.Read(buf)
+// if err != nil {
+// //返回Accept错误
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: ErrorMsg,
+// Msg: err,
+// }:
+// }
+// continue
+// }
+
+// ip := net.ParseIP(strings.Split(serConn.RemoteAddr().String(), ":")[0])
+// var deny bool
+// for i := 0; !deny && i < len(matchfunc); i++ {
+// deny = deny && !matchfunc[i](ip)
+// }
+// if deny {
+// //返回Deny
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: DenyMsg,
+// Msg: serConn.RemoteAddr(),
+// }:
+// }
+// serConn.Close()
+// continue
+// }
+
+// //返回Accept
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: AcceptMsg,
+// Msg: serConn.RemoteAddr(),
+// }:
+// }
+
+// select {
+// case <-closec:
+// default:
+// break
+// }
+
+// targetConn := genConn(cliAddr)
+// if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil {
+// //返回Accept错误
+// select {
+// default:
+// case msg_chan <- ForwardMsg{
+// Type: ErrorMsg,
+// Msg: err,
+// }:
+// }
+// } else {
+// go func() {
+// defer udpPool.Put(targetConn)
+
+// buf := make([]byte, 20480)
+// for {
+// targetConn.p.SetDeadline(time.Now().Add(time.Second * 20))
+// n, _, e := targetConn.p.ReadFromUDP(buf)
+// if e != nil {
+// return
+// }
+// if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil {
+// return
+// }
+// }
+// }()
+// }
+// }
+// }(serConn, targetAddr, msg_chan)
+
+// return
+// }
+
func (this *netl) GetLocalDns() error {
if runtime.GOOS == "windows" {
cmd := exec.Command("nslookup", "127.0.0.1")