import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
"net"
"net/url"
+ "os"
"os/exec"
"runtime"
"strings"
- "sync"
"time"
+ "github.com/dustin/go-humanize"
"github.com/miekg/dns"
+ pe "github.com/qydysky/part/errors"
pfile "github.com/qydysky/part/file"
+ psync "github.com/qydysky/part/sync"
)
var (
ErrDnsNoAnswer = errors.New("ErrDnsNoAnswer")
ErrNetworkNoSupport = errors.New("ErrNetworkNoSupport")
- ErrUdpOverflow = errors.New("ErrUdpOverflow") // default:1500 set higher pkgSize at NewUdpListener
+ ErrUdpOverflow = errors.New("ErrUdpOverflow")
)
type netl struct {
return true
}
-const (
- ErrorMsg = iota
- WarnMsg
- AcceptMsg
- DenyMsg
- LisnMsg
- ClosMsg
-)
-
-// when Type is ErrorMsg, Msg is set to error
-// when Type is AcceptMsg, Msg is set to net.Addr
-// when Type is DenyMsg, Msg is set to net.Addr
-// when Type is LisnMsg, Msg is set to net.Addr
-type ForwardMsg struct {
- Type int
- Msg interface{}
-}
-
// 桥
func connBridge(a, b net.Conn, bufSize int) {
- fmt.Println(b.LocalAddr())
- var wg sync.WaitGroup
+ var wg = make(chan struct{}, 3)
- wg.Add(2)
- go func() {
- defer func() {
- fmt.Println("close r", b.LocalAddr())
- wg.Done()
- }()
-
- buf := make([]byte, bufSize)
+ buf := make([]byte, bufSize*2)
+ go func() {
+ buf := buf[:bufSize]
for {
if n, err := a.Read(buf); err != nil {
- a.Close()
- return
+ if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrDeadlineExceeded) && !errors.Is(err, net.ErrClosed) {
+ fmt.Println(err)
+ }
+ break
} else if _, err = b.Write(buf[:n]); err != nil {
- fmt.Println(err)
- return
+ if !errors.Is(err, io.EOF) {
+ fmt.Println(err)
+ }
+ break
}
}
+ wg <- struct{}{}
}()
go func() {
- defer func() {
- fmt.Println("close w", b.LocalAddr())
- wg.Done()
- }()
-
- buf := make([]byte, bufSize)
-
+ buf := buf[bufSize:]
for {
if n, err := b.Read(buf); err != nil {
- b.Close()
- return
+ if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrDeadlineExceeded) && !errors.Is(err, net.ErrClosed) {
+ fmt.Println(err)
+ }
+ break
} else if _, err = a.Write(buf[:n]); err != nil {
- fmt.Println(err)
- return
+ if !errors.Is(err, io.EOF) {
+ fmt.Println(err)
+ }
+ break
}
}
+ wg <- struct{}{}
}()
- wg.Wait()
- fmt.Println("close", b.LocalAddr())
+ <-wg
+ a.Close()
+ b.Close()
}
-func Forward(targetaddr, listenaddr string, acceptCIDRs []string) (closef func(), msg_chan chan ForwardMsg) {
- msg_chan = make(chan ForwardMsg, 1000)
+var (
+ ErrForwardAccept pe.Action = `ErrForwardAccept`
+ ErrForwardDail pe.Action = `ErrForwardDail`
+)
+
+type ForwardMsgFunc interface {
+ ErrorMsg(targetaddr, listenaddr string, e error)
+ WarnMsg(targetaddr, listenaddr string, e error)
+ AcceptMsg(remote net.Addr, targetaddr string) (ConFinMsg func())
+ DenyMsg(remote net.Addr, targetaddr string)
+ LisnMsg(targetaddr, listenaddr string)
+ ClosMsg(targetaddr, listenaddr string)
+}
+
+func Forward(targetaddr, listenaddr string, acceptCIDRs []string, callBack ForwardMsgFunc) (closef func()) {
closef = func() {}
lisNet := strings.Split(listenaddr, "://")[0]
lisAddr := strings.Split(listenaddr, "://")[1]
tarNet := strings.Split(targetaddr, "://")[0]
tarAddr := strings.Split(targetaddr, "://")[1]
+ lisIsUdp := strings.Contains(lisNet, "udp")
+ tarIsUdp := strings.Contains(tarNet, "udp")
+
+ if (!lisIsUdp && tarIsUdp) || (lisIsUdp && !tarIsUdp) {
+ callBack.ErrorMsg(targetaddr, listenaddr, ErrNetworkNoSupport)
+ return
+ }
//尝试监听
var listener net.Listener
err = ErrNetworkNoSupport
}
if err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
+ callBack.ErrorMsg(targetaddr, listenaddr, err)
return
}
}
err = ErrNetworkNoSupport
}
if err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
+ callBack.ErrorMsg(targetaddr, listenaddr, err)
return
}
}
- closec := make(chan struct{})
//初始化关闭方法
closef = func() {
listener.Close()
- close(closec)
+ callBack.ClosMsg(targetaddr, listenaddr)
}
//返回监听地址
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: LisnMsg,
- Msg: listener.Addr(),
- }:
- }
+ callBack.LisnMsg(targetaddr, listenaddr)
matchfunc := []func(ip net.IP) bool{}
for _, cidr := range acceptCIDRs {
if _, cidrx, err := net.ParseCIDR(cidr); err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: ErrorMsg,
- Msg: err,
- }:
- }
+ callBack.ErrorMsg(targetaddr, listenaddr, err)
return
} else {
matchfunc = append(matchfunc, cidrx.Contains)
}
//开始准备转发
- go func(listener net.Listener, msg_chan chan ForwardMsg) {
- defer close(msg_chan)
+ go func(listener net.Listener) {
defer listener.Close()
for {
proxyconn, err := listener.Accept()
+ if errors.Is(err, ErrUdpConnected) {
+ continue
+ }
if err != nil {
//返回Accept错误
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: WarnMsg,
- Msg: err,
- }:
- }
+ callBack.WarnMsg(targetaddr, listenaddr, pe.Join(ErrForwardAccept, err))
continue
}
host, _, err := net.SplitHostPort(proxyconn.RemoteAddr().String())
if err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: WarnMsg,
- Msg: err,
- }:
- }
+ callBack.WarnMsg(targetaddr, listenaddr, err)
continue
}
}
if !accept {
//返回Deny
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: DenyMsg,
- Msg: proxyconn.RemoteAddr(),
- }:
- }
+ callBack.DenyMsg(proxyconn.RemoteAddr(), targetaddr)
proxyconn.Close()
continue
}
//返回Accept
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: AcceptMsg,
- Msg: proxyconn.RemoteAddr(),
- }:
- }
-
- select {
- case <-closec:
- default:
- break
- }
+ conFin := callBack.AcceptMsg(proxyconn.RemoteAddr(), targetaddr)
go func() {
+ defer conFin()
targetconn, err := net.Dial(tarNet, tarAddr)
if err != nil {
- select {
- default:
- case msg_chan <- ForwardMsg{
- Type: WarnMsg,
- Msg: err,
- }:
- }
+ callBack.WarnMsg(targetaddr, listenaddr, pe.Join(ErrForwardDail, err))
return
}
- connBridge(proxyconn, targetconn, 20480)
-
- switch lisNet {
- case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
- case "udp", "udp4", "udp6":
- time.Sleep(time.Second * 10)
- default:
- }
+ // if !lisIsUdp && !tarIsUdp {
+ // tcp2tcpConnBridge(proxyconn, targetconn, 65535)
+ // } else if lisIsUdp && tarIsUdp {
+ connBridge(proxyconn, targetconn, 65535)
+ // } else {
+ // connBridge(proxyconn, targetconn, 65535)
+ // }
}()
}
- }(listener, msg_chan)
+ }(listener)
return
}
+var ErrUdpConnOverflow = errors.New(`ErrUdpConnOverflow`)
+
type udpConn struct {
+ e error
conn *net.UDPConn
remoteAdd *net.UDPAddr
- reader io.Reader
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ buf chan []byte
+ closef func() error
}
-func NewUdpConn(per []byte, remoteAdd *net.UDPAddr, conn *net.UDPConn) *udpConn {
- return &udpConn{
- conn: conn,
- remoteAdd: remoteAdd,
- reader: bytes.NewReader(per),
+func (t *udpConn) SetBuf(b []byte) {
+ tmp := make([]byte, len(b))
+ copy(tmp, b)
+ select {
+ case t.buf <- tmp:
+ default:
+ t.e = ErrUdpConnOverflow
+ }
+}
+func (t *udpConn) Read(b []byte) (n int, err error) {
+ select {
+ case tmp := <-t.buf:
+ n = copy(b, tmp)
+ case <-t.ctx.Done():
+ err = os.ErrDeadlineExceeded
+ }
+ return
+}
+func (t *udpConn) Write(b []byte) (n int, err error) {
+ select {
+ case <-t.ctx.Done():
+ err = os.ErrDeadlineExceeded
+ default:
+ n, err = t.conn.WriteToUDP(b, t.remoteAdd)
+ if err != nil {
+ t.ctx.Done()
+ }
}
+ return
}
-func (t udpConn) Read(b []byte) (n int, err error) {
- return t.reader.Read(b)
+func (t *udpConn) Close() error {
+ t.ctxCancel()
+ return t.closef()
}
-func (t udpConn) Write(b []byte) (n int, err error) {
- return t.conn.WriteToUDP(b, t.remoteAdd)
+func (t *udpConn) LocalAddr() net.Addr { return t.conn.LocalAddr() }
+func (t *udpConn) RemoteAddr() net.Addr { return t.remoteAdd }
+func (t *udpConn) SetDeadline(b time.Time) error {
+ time.AfterFunc(time.Until(b), func() {
+ t.Close()
+ })
+ return nil
}
-func (t udpConn) Close() error { return nil }
-func (t udpConn) LocalAddr() net.Addr { return t.conn.LocalAddr() }
-func (t udpConn) RemoteAddr() net.Addr { return t.remoteAdd }
-func (t udpConn) SetDeadline(b time.Time) error { return t.conn.SetDeadline(b) }
-func (t udpConn) SetReadDeadline(b time.Time) error { return t.conn.SetReadDeadline(b) }
-func (t udpConn) SetWriteDeadline(b time.Time) error { return t.conn.SetWriteDeadline(b) }
+func (t *udpConn) SetReadDeadline(b time.Time) error { return t.SetDeadline(b) }
+func (t *udpConn) SetWriteDeadline(b time.Time) error { return t.SetDeadline(b) }
type udpLis struct {
- addr net.Addr
- conn *net.UDPConn
- pkgSize int
+ udpAddr *net.UDPAddr
+ c <-chan *udpConn
+ closef func() error
}
-func NewUdpListener(network, listenaddr string, pkgSize ...int) (*udpLis, error) {
+var ErrUdpConnected error = errors.New("ErrUdpConnected")
+
+func NewUdpListener(network, listenaddr string) (*udpLis, error) {
udpAddr, err := net.ResolveUDPAddr(network, listenaddr)
if err != nil {
return nil, err
}
- pkgSize = append(pkgSize, 1500)
if conn, err := net.ListenUDP(network, udpAddr); err != nil {
return nil, err
} else {
- return &udpLis{
- addr: udpAddr,
- conn: conn,
- pkgSize: pkgSize[0],
- }, nil
+ c := make(chan *udpConn, 10)
+ lis := &udpLis{
+ udpAddr: udpAddr,
+ c: c,
+ closef: func() error {
+ return conn.Close()
+ },
+ }
+ go func() {
+ var link psync.MapG[string, *udpConn]
+ buf := make([]byte, humanize.MByte)
+ for {
+ n, remoteAdd, e := conn.ReadFromUDP(buf)
+ if e != nil {
+ c <- &udpConn{e: e}
+ return
+ }
+ if udpc, ok := link.Load(remoteAdd.String()); ok {
+ udpc.SetBuf(buf[:n])
+ c <- &udpConn{e: ErrUdpConnected}
+ } else {
+ udpc := &udpConn{
+ conn: conn,
+ remoteAdd: remoteAdd,
+ closef: func() error {
+ link.Delete(remoteAdd.String())
+ return nil
+ },
+ buf: make(chan []byte, 5),
+ }
+ udpc.ctx, udpc.ctxCancel = context.WithTimeout(context.Background(), time.Second*30)
+ udpc.SetBuf(buf[:n])
+ link.Store(remoteAdd.String(), udpc)
+ c <- udpc
+ }
+ }
+ }()
+ return lis, nil
}
}
-func (t udpLis) Accept() (net.Conn, error) {
- buf := make([]byte, t.pkgSize)
- if n, remoteAdd, e := t.conn.ReadFromUDP(buf); e != nil {
- return nil, e
- } else if n == t.pkgSize {
- return nil, ErrUdpOverflow
- } else {
- return NewUdpConn(buf[:n], remoteAdd, t.conn), nil
- }
+func (t *udpLis) Accept() (net.Conn, error) {
+ udpc := <-t.c
+ return udpc, udpc.e
}
-func (t udpLis) Close() error {
- return t.conn.Close()
+func (t *udpLis) Close() error {
+ return t.closef()
}
-func (t udpLis) Addr() net.Addr {
- return t.addr
+func (t *udpLis) Addr() net.Addr {
+ return t.udpAddr
}
// type ConnPool struct {