"time"
"github.com/miekg/dns"
+ pool "github.com/qydysky/part/pool"
)
type netl struct {
Msg interface{}
}
+// 桥
+func connBridge(a, b net.Conn) {
+ fin := make(chan bool, 1)
+ var wg sync.WaitGroup
+
+ wg.Add(2)
+ go func() {
+ defer func() {
+ a.Close()
+ b.Close()
+ fin <- true
+ wg.Done()
+ }()
+
+ buf := make([]byte, 20480)
+
+ for {
+ select {
+ case <-fin:
+ return
+ default:
+ n, err := a.Read(buf)
+
+ if err != nil {
+ return
+ }
+ b.Write(buf[:n])
+ }
+ }
+ }()
+
+ go func() {
+ defer func() {
+ a.Close()
+ b.Close()
+ fin <- true
+ wg.Done()
+ }()
+
+ buf := make([]byte, 20480)
+
+ for {
+ select {
+ case <-fin:
+ return
+ default:
+ n, err := b.Read(buf)
+
+ if err != nil {
+ return
+ }
+ a.Write(buf[:n])
+ }
+ }
+ }()
+
+ wg.Wait()
+}
+
+// "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func Forward(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
//初始化消息通道
msg_chan = make(chan ForwardMsg, 1000)
defer close(msg_chan)
defer listener.Close()
- //tcp 桥
- tcpBridge2 := func(a, b net.Conn) {
- fin := make(chan bool, 1)
- var wg sync.WaitGroup
-
- wg.Add(2)
- go func() {
- defer func() {
- a.Close()
- b.Close()
- fin <- true
- wg.Done()
- }()
-
- buf := make([]byte, 20480)
-
- for {
- select {
- case <-fin:
- return
- default:
- n, err := a.Read(buf)
-
- if err != nil {
- return
- }
- b.Write(buf[:n])
- }
- }
- }()
-
- go func() {
- defer func() {
- a.Close()
- b.Close()
- fin <- true
- wg.Done()
- }()
-
- buf := make([]byte, 20480)
-
- for {
- select {
- case <-fin:
- return
- default:
- n, err := b.Read(buf)
-
- if err != nil {
- return
- }
- a.Write(buf[:n])
- }
- }
- }()
-
- wg.Wait()
- }
-
for {
proxyconn, err := listener.Accept()
if err != nil {
continue
}
- go tcpBridge2(proxyconn, targetconn)
+ go connBridge(proxyconn, targetconn)
}
}(listener, targetaddr, network, msg_chan)
return
}
+func ForwardUdp(targetaddr, network, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
+ //初始化消息通道
+ msg_chan = make(chan ForwardMsg, 1000)
+
+ lisAddr := func(network, listenaddr string) (*net.UDPConn, error) {
+ if udpAddr, err := net.ResolveUDPAddr(network, listenaddr); err != nil {
+ return nil, err
+ } else if conn, err := net.ListenUDP(network, udpAddr); err != nil {
+ return nil, err
+ } else {
+ return conn, nil
+ }
+ }
+
+ targetAddr, err := net.ResolveUDPAddr(network, targetaddr)
+ if err != nil {
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ return
+ }
+
+ serConn, err := lisAddr(network, listenaddr)
+ if err != nil {
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ return
+ }
+
+ closec := make(chan struct{})
+ //初始化关闭方法
+ closef = func() {
+ serConn.Close()
+ close(closec)
+ }
+
+ //返回监听端口
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: PortMsg,
+ Msg: serConn.LocalAddr().(*net.UDPAddr).Port,
+ }:
+ }
+
+ matchfunc := []func(ip net.IP) bool{}
+
+ for _, cidr := range acceptCIDRs {
+ if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ return
+ } else {
+ matchfunc = append(matchfunc, cidrx.Contains)
+ }
+ }
+
+ //开始准备转发
+ go func(serConn *net.UDPConn, targetAddr *net.UDPAddr, msg_chan chan ForwardMsg) {
+ defer close(msg_chan)
+ defer serConn.Close()
+
+ type UDPConn struct {
+ p *net.UDPConn
+ pooled bool
+ }
+
+ var (
+ connMap = make(map[string]*UDPConn)
+ udpPool = pool.New[UDPConn](pool.PoolFunc[UDPConn]{
+ New: func() *UDPConn {
+ conn, _ := net.ListenUDP(network, nil)
+ return &UDPConn{p: conn}
+ },
+ InUse: func(u *UDPConn) bool {
+ return !u.pooled
+ },
+ Reuse: func(u *UDPConn) *UDPConn {
+ u.pooled = false
+ return u
+ },
+ Pool: func(u *UDPConn) *UDPConn {
+ u.pooled = true
+ return u
+ },
+ }, 100)
+ )
+ genConn := func(cliAddr *net.UDPAddr) *UDPConn {
+ if conn, ok := connMap[cliAddr.String()]; ok {
+ return conn
+ } else {
+ conn = udpPool.Get()
+ connMap[cliAddr.String()] = conn
+ return conn
+ }
+ }
+
+ buf := make([]byte, 20480)
+ for {
+ n, cliAddr, err := serConn.ReadFromUDP(buf)
+ if err != nil {
+ //返回Accept错误
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ continue
+ }
+
+ var accpet bool
+ for i := 0; i < len(matchfunc); i++ {
+ accpet = accpet || matchfunc[i](cliAddr.IP)
+ }
+ if !accpet {
+ //返回Deny
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: DenyMsg,
+ Msg: net.Addr(cliAddr),
+ }:
+ }
+ continue
+ }
+
+ //返回Accept
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: AcceptMsg,
+ Msg: net.Addr(cliAddr),
+ }:
+ }
+
+ select {
+ case <-closec:
+ default:
+ break
+ }
+
+ targetConn := genConn(cliAddr)
+ if _, err := targetConn.p.WriteToUDP(buf[:n], targetAddr); err != nil {
+ //返回Accept错误
+ select {
+ default:
+ case msg_chan <- ForwardMsg{
+ Type: ErrorMsg,
+ Msg: err,
+ }:
+ }
+ } else {
+ go func() {
+ defer udpPool.Put(targetConn)
+
+ buf := make([]byte, 20480)
+ for {
+ targetConn.p.SetDeadline(time.Now().Add(time.Second * 20))
+ n, _, e := targetConn.p.ReadFromUDP(buf)
+ if e != nil {
+ return
+ }
+ if n, e = targetConn.p.WriteToUDP(buf[:n], cliAddr); err != nil {
+ return
+ }
+ }
+ }()
+ }
+ }
+ }(serConn, targetAddr, msg_chan)
+
+ return
+}
+
func (this *netl) GetLocalDns() error {
if runtime.GOOS == "windows" {
cmd := exec.Command("nslookup", "127.0.0.1")