]> 127.0.0.1 Git - part/.git/commitdiff
Fix reqf v0.23.1
authorqydysky <32743305+qydysky@users.noreply.github.com>
Wed, 1 Mar 2023 18:32:56 +0000 (02:32 +0800)
committerqydysky <32743305+qydysky@users.noreply.github.com>
Wed, 1 Mar 2023 18:32:56 +0000 (02:32 +0800)
reqf/Reqf.go
reqf/Reqf_test.go

index 639db16d67a3329822d307498a4159694c73f8ce..44f809480a422aee6197618c0dc03e21e024c9b1 100644 (file)
@@ -11,6 +11,7 @@ import (
        "os"
        "strconv"
        "strings"
+       "sync"
        "time"
 
        flate "compress/flate"
@@ -48,61 +49,59 @@ type Req struct {
        Response *http.Response
        UsedTime time.Duration
 
-       cancel     *signal.Signal
-       running    *signal.Signal
-       responBuf  *bytes.Buffer
+       cancelF func()
+       cancel  *signal.Signal
+       running *signal.Signal
+
        responFile *os.File
-       asyncErr   error
+       err        error
+
+       l sync.Mutex
 }
 
 func New() *Req {
        return new(Req)
 }
 
-// func main(){
-//     var _ReqfVal = ReqfVal{
-//         Url:url,
-//         Proxy:proxy,
-//             Timeout:10,
-//             Retry:2,
-//     }
-//     Reqf(_ReqfVal)
-// }
-
 func (t *Req) Reqf(val Rval) error {
-
-       if val.SaveToChan != nil && len(val.SaveToChan) == 1 && !val.Async {
-               panic("must make sure chan size larger then 1 or use Async true")
-       }
-       if val.SaveToPipeWriter != nil && !val.Async {
-               panic("SaveToPipeWriter must use Async true")
-       }
+       t.l.Lock()
 
        t.Respon = []byte{}
        t.Response = nil
        t.UsedTime = 0
+       t.cancelF = nil
        t.cancel = signal.Init()
        t.running = signal.Init()
+       t.responFile = nil
+       t.err = nil
 
-       var returnErr error
-
-       _val := val
-
-       for SleepTime, Retry := _val.SleepTime, _val.Retry; Retry >= 0; Retry -= 1 {
-               returnErr = t.Reqf_1(_val)
+       go func() {
                select {
-               case <-t.cancel.WaitC(): //cancel
-                       return returnErr
-               default:
-                       if returnErr == nil {
-                               return nil
+               case <-t.cancel.Chan:
+                       if t.cancelF != nil {
+                               t.cancelF()
                        }
+               case <-t.running.Chan:
                }
-               time.Sleep(time.Duration(SleepTime) * time.Millisecond)
-       }
+       }()
+       go func() {
+               beginTime := time.Now()
+               _val := val
+
+               for SleepTime, Retry := _val.SleepTime, _val.Retry; Retry >= 0; Retry -= 1 {
+                       t.err = t.Reqf_1(_val)
+                       if t.err == nil || IsCancel(t.err) {
+                               break
+                       }
+                       time.Sleep(time.Duration(SleepTime) * time.Millisecond)
+               }
+
+               t.UsedTime = time.Since(beginTime)
+               t.running.Done()
+       }()
 
-       if !val.Async || returnErr != nil {
-               t.asyncErr = returnErr
+       if !val.Async {
+               t.Wait()
                if val.SaveToChan != nil {
                        close(val.SaveToChan)
                }
@@ -112,13 +111,25 @@ func (t *Req) Reqf(val Rval) error {
                if val.SaveToPipeWriter != nil {
                        val.SaveToPipeWriter.Close()
                }
-               if t.responBuf != nil {
-                       t.Respon = t.responBuf.Bytes()
-               }
-               t.running.Done()
                t.cancel.Done()
+               t.l.Unlock()
+       } else {
+               go func() {
+                       t.Wait()
+                       if val.SaveToChan != nil {
+                               close(val.SaveToChan)
+                       }
+                       if t.responFile != nil {
+                               t.responFile.Close()
+                       }
+                       if val.SaveToPipeWriter != nil {
+                               val.SaveToPipeWriter.Close()
+                       }
+                       t.cancel.Done()
+                       t.l.Unlock()
+               }()
        }
-       return returnErr
+       return t.err
 }
 
 func (t *Req) Reqf_1(val Rval) (err error) {
@@ -126,8 +137,6 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                Header map[string]string = val.Header
        )
 
-       var beginTime time.Time = time.Now()
-
        var client http.Client
 
        if Header == nil {
@@ -166,14 +175,7 @@ func (t *Req) Reqf_1(val Rval) (err error) {
        if val.Timeout > 0 {
                cx, cancel = context.WithTimeout(cx, time.Duration(val.Timeout)*time.Millisecond)
        }
-
-       go func() {
-               select {
-               case <-t.cancel.WaitC():
-                       cancel()
-               case <-t.running.WaitC():
-               }
-       }()
+       t.cancelF = cancel
 
        req, e := http.NewRequest(Method, val.Url, body)
        if e != nil {
@@ -205,19 +207,23 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                req.Header.Set(k, v)
        }
 
-       resp, err := client.Do(req)
-       if v, ok := Header["Connection"]; ok && strings.ToLower(v) != "keep-alive" {
-               defer client.CloseIdleConnections()
+       if !t.cancel.Islive() {
+               err = context.Canceled
+               return
        }
 
-       if err != nil {
-               return err
+       resp, e := client.Do(req)
+
+       if e != nil {
+               err = e
+               return
+       }
+
+       if v, ok := Header["Connection"]; ok && strings.ToLower(v) != "keep-alive" {
+               defer client.CloseIdleConnections()
        }
 
        t.Response = resp
-       defer func() {
-               t.UsedTime = time.Since(beginTime)
-       }()
 
        if val.JustResponseCode {
                return
@@ -227,12 +233,14 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                err = errors.New(strconv.Itoa(resp.StatusCode))
        }
 
+       var responBuf *bytes.Buffer
        var ws []io.Writer
        if val.SaveToPath != "" {
-               t.responFile, err = os.Create(val.SaveToPath)
+               t.responFile, e = os.Create(val.SaveToPath)
                if err != nil {
                        t.responFile.Close()
-                       return err
+                       err = e
+                       return
                }
                ws = append(ws, t.responFile)
        }
@@ -240,12 +248,10 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                ws = append(ws, val.SaveToPipeWriter)
        }
        if !val.NoResponse {
-               if t.responBuf == nil {
-                       t.responBuf = new(bytes.Buffer)
-               } else {
-                       t.responBuf.Reset()
+               if responBuf == nil {
+                       responBuf = new(bytes.Buffer)
                }
-               ws = append(ws, t.responBuf)
+               ws = append(ws, responBuf)
        }
 
        w := io.MultiWriter(ws...)
@@ -266,59 +272,40 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                resReader = resp.Body
        }
 
-       go func() {
-               buf := make([]byte, 512)
-
-               for {
-                       if n, e := resReader.Read(buf); n != 0 {
-                               w.Write(buf[:n])
-                               select {
-                               case val.SaveToChan <- buf[:n]:
-                               default:
-                               }
-                       } else if e != nil {
-                               if !errors.Is(e, io.EOF) {
-                                       err = e
-                               }
-                               break
-                       }
+       buf := make([]byte, 512)
 
-                       if !t.cancel.Islive() {
-                               err = context.Canceled
-                               break
+       for {
+               if n, e := resReader.Read(buf); n != 0 {
+                       w.Write(buf[:n])
+                       select {
+                       case val.SaveToChan <- buf[:n]:
+                       default:
                        }
+               } else if e != nil {
+                       if !errors.Is(e, io.EOF) {
+                               err = e
+                       }
+                       break
                }
 
-               if val.Async {
-                       t.asyncErr = err
+               if !t.cancel.Islive() {
+                       err = context.Canceled
+                       break
                }
-               resp.Body.Close()
-               if val.SaveToChan != nil {
-                       close(val.SaveToChan)
-               }
-               if t.responFile != nil {
-                       t.responFile.Close()
-               }
-               if val.SaveToPipeWriter != nil {
-                       val.SaveToPipeWriter.Close()
-               }
-               if t.responBuf != nil {
-                       t.Respon = t.responBuf.Bytes()
-               }
-               t.running.Done()
-       }()
-       if !val.Async {
-               t.Wait()
        }
-       // if _, e := io.Copy(w, resp.Body); e != nil {
-       //      err = e
-       // }
+
+       resp.Body.Close()
+
+       if responBuf != nil {
+               t.Respon = responBuf.Bytes()
+       }
+
        return
 }
 
 func (t *Req) Wait() error {
        t.running.Wait()
-       return t.asyncErr
+       return t.err
 }
 
 func (t *Req) Cancel() { t.Close() }
index 7113b8b6b194a7a85c6d328d6be98d17c85df74d..6c5fe3b592c397bb559591ebf3cb97f9846c9547 100644 (file)
@@ -12,8 +12,9 @@ import (
        web "github.com/qydysky/part/web"
 )
 
-func Test_req(t *testing.T) {
-       addr := "127.0.0.1:10001"
+var addr = "127.0.0.1:10001"
+
+func init() {
        s := web.New(&http.Server{
                Addr:         addr,
                WriteTimeout: time.Second * time.Duration(10),
@@ -47,13 +48,27 @@ func Test_req(t *testing.T) {
                        s.Server.Shutdown(context.Background())
                },
        })
+}
+
+func Test_req7(t *testing.T) {
+       r := New()
+       r.Reqf(Rval{
+               Url:   "http://" + addr + "/to",
+               Async: true,
+       })
+       r.Cancel()
+       if !IsCancel(r.Wait()) {
+               t.Error("async Cancel fail")
+       }
+}
 
+func Test_req(t *testing.T) {
        r := New()
        r.Reqf(Rval{
                Url: "http://" + addr + "/br",
        })
        if !bytes.Equal(r.Respon, []byte("abc强强强强")) {
-               t.Error("br fail")
+               t.Error("br fail", r.Respon)
        }
        r.Reqf(Rval{
                Url: "http://" + addr + "/gzip",
@@ -67,6 +82,10 @@ func Test_req(t *testing.T) {
        if !bytes.Equal(r.Respon, []byte("abc强强强强")) {
                t.Error("flate fail")
        }
+}
+
+func Test_req2(t *testing.T) {
+       r := New()
        {
                e := r.Reqf(Rval{
                        Url:     "http://" + addr + "/to",
@@ -76,16 +95,24 @@ func Test_req(t *testing.T) {
                        t.Error("Timeout fail")
                }
        }
+}
+
+func Test_req4(t *testing.T) {
+       r := New()
        {
                r.Reqf(Rval{
                        Url:     "http://" + addr + "/to",
                        Timeout: 1000,
                        Async:   true,
                })
-               if !IsTimeout(r.Wait()) {
-                       t.Error("Async Timeout fail")
+               if e := r.Wait(); !IsTimeout(e) {
+                       t.Error("Async Timeout fail", e)
                }
        }
+}
+
+func Test_req5(t *testing.T) {
+       r := New()
        {
                c := make(chan []byte)
                r.Reqf(Rval{
@@ -101,9 +128,13 @@ func Test_req(t *testing.T) {
                        }
                }
                if !IsTimeout(r.Wait()) {
-                       t.Error("async cancel fail")
+                       t.Error("async IsTimeout fail")
                }
        }
+}
+
+func Test_req6(t *testing.T) {
+       r := New()
        {
                c := make(chan []byte)
                r.Reqf(Rval{
@@ -136,6 +167,10 @@ func Test_req(t *testing.T) {
                        t.Error("Cancel fail")
                }
        }
+}
+
+func Test_req3(t *testing.T) {
+       r := New()
        {
                rc, wc := io.Pipe()
                c := make(chan struct{})
@@ -180,6 +215,25 @@ func Test_req(t *testing.T) {
                        }
                        close(c)
                }()
+               r.Reqf(Rval{
+                       Url:              "http://" + addr + "/flate",
+                       SaveToPipeWriter: wc,
+               })
+               <-c
+       }
+       {
+               rc, wc := io.Pipe()
+               c := make(chan struct{})
+               go func() {
+                       var buf []byte = make([]byte, 1<<16)
+                       n, _ := rc.Read(buf)
+                       d := buf[:n]
+                       // d, _ := io.ReadAll(rc)
+                       if !bytes.Equal(d, []byte("abc强强强强")) {
+                               t.Error("flate fail")
+                       }
+                       close(c)
+               }()
                r.Reqf(Rval{
                        Url:              "http://" + addr + "/flate",
                        SaveToPipeWriter: wc,
@@ -192,12 +246,9 @@ func Test_req(t *testing.T) {
                        Url:   "http://" + addr + "/flate",
                        Async: true,
                })
-               if len(r.Respon) != 0 {
-                       t.Error("async fail")
-               }
                r.Wait()
                if !bytes.Equal(r.Respon, []byte("abc强强强强")) {
-                       t.Error("async fail")
+                       t.Error("async fail", r.Respon)
                }
        }
        {
@@ -211,12 +262,11 @@ func Test_req(t *testing.T) {
                if len(r.Respon) != 0 {
                        t.Error("io async fail")
                }
-               d, _ := io.ReadAll(rc)
+               var buf []byte = make([]byte, 1<<16)
+               n, _ := rc.Read(buf)
+               d := buf[:n]
                if !bytes.Equal(d, []byte("abc强强强强")) {
-                       t.Error("io async fail")
-               }
-               if !bytes.Equal(r.Respon, []byte("abc强强强强")) {
-                       t.Error("io async fail")
+                       t.Error("io async fail", d)
                }
        }
 }