]> 127.0.0.1 Git - part/.git/commitdiff
fix v0.19.2
authorqydysky <32743305+qydysky@users.noreply.github.com>
Mon, 14 Nov 2022 16:07:19 +0000 (00:07 +0800)
committerqydysky <32743305+qydysky@users.noreply.github.com>
Mon, 14 Nov 2022 16:07:19 +0000 (00:07 +0800)
reqf/Reqf.go
reqf/Reqf_test.go

index 7d622f843ce0b52b8d9ddedca086c7478c332001..639db16d67a3329822d307498a4159694c73f8ce 100644 (file)
@@ -52,6 +52,7 @@ type Req struct {
        running    *signal.Signal
        responBuf  *bytes.Buffer
        responFile *os.File
+       asyncErr   error
 }
 
 func New() *Req {
@@ -69,6 +70,7 @@ func New() *Req {
 // }
 
 func (t *Req) Reqf(val Rval) error {
+
        if val.SaveToChan != nil && len(val.SaveToChan) == 1 && !val.Async {
                panic("must make sure chan size larger then 1 or use Async true")
        }
@@ -79,6 +81,8 @@ func (t *Req) Reqf(val Rval) error {
        t.Respon = []byte{}
        t.Response = nil
        t.UsedTime = 0
+       t.cancel = signal.Init()
+       t.running = signal.Init()
 
        var returnErr error
 
@@ -97,6 +101,23 @@ func (t *Req) Reqf(val Rval) error {
                time.Sleep(time.Duration(SleepTime) * time.Millisecond)
        }
 
+       if !val.Async || returnErr != nil {
+               t.asyncErr = returnErr
+               if val.SaveToChan != nil {
+                       close(val.SaveToChan)
+               }
+               if t.responFile != nil {
+                       t.responFile.Close()
+               }
+               if val.SaveToPipeWriter != nil {
+                       val.SaveToPipeWriter.Close()
+               }
+               if t.responBuf != nil {
+                       t.Respon = t.responBuf.Bytes()
+               }
+               t.running.Done()
+               t.cancel.Done()
+       }
        return returnErr
 }
 
@@ -145,22 +166,21 @@ func (t *Req) Reqf_1(val Rval) (err error) {
        if val.Timeout > 0 {
                cx, cancel = context.WithTimeout(cx, time.Duration(val.Timeout)*time.Millisecond)
        }
-       req, e := http.NewRequest(Method, val.Url, body)
-       if e != nil {
-               panic(e)
-       }
-       req = req.WithContext(cx)
 
-       var done = make(chan struct{})
-       defer close(done)
        go func() {
                select {
                case <-t.cancel.WaitC():
                        cancel()
-               case <-done:
+               case <-t.running.WaitC():
                }
        }()
 
+       req, e := http.NewRequest(Method, val.Url, body)
+       if e != nil {
+               panic(e)
+       }
+       req = req.WithContext(cx)
+
        for _, v := range val.Cookies {
                req.AddCookie(v)
        }
@@ -246,10 +266,9 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                resReader = resp.Body
        }
 
-       t.running = signal.Init()
-       t.cancel = signal.Init()
        go func() {
                buf := make([]byte, 512)
+
                for {
                        if n, e := resReader.Read(buf); n != 0 {
                                w.Write(buf[:n])
@@ -269,6 +288,10 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                                break
                        }
                }
+
+               if val.Async {
+                       t.asyncErr = err
+               }
                resp.Body.Close()
                if val.SaveToChan != nil {
                        close(val.SaveToChan)
@@ -282,7 +305,6 @@ func (t *Req) Reqf_1(val Rval) (err error) {
                if t.responBuf != nil {
                        t.Respon = t.responBuf.Bytes()
                }
-               t.cancel.Done()
                t.running.Done()
        }()
        if !val.Async {
@@ -294,8 +316,9 @@ func (t *Req) Reqf_1(val Rval) (err error) {
        return
 }
 
-func (t *Req) Wait() {
+func (t *Req) Wait() error {
        t.running.Wait()
+       return t.asyncErr
 }
 
 func (t *Req) Cancel() { t.Close() }
index 8c6569a8085c2cc345b9fd66f7d8ab01ba635d16..7113b8b6b194a7a85c6d328d6be98d17c85df74d 100644 (file)
@@ -67,6 +67,43 @@ func Test_req(t *testing.T) {
        if !bytes.Equal(r.Respon, []byte("abc强强强强")) {
                t.Error("flate fail")
        }
+       {
+               e := r.Reqf(Rval{
+                       Url:     "http://" + addr + "/to",
+                       Timeout: 1000,
+               })
+               if !IsTimeout(e) {
+                       t.Error("Timeout fail")
+               }
+       }
+       {
+               r.Reqf(Rval{
+                       Url:     "http://" + addr + "/to",
+                       Timeout: 1000,
+                       Async:   true,
+               })
+               if !IsTimeout(r.Wait()) {
+                       t.Error("Async Timeout fail")
+               }
+       }
+       {
+               c := make(chan []byte)
+               r.Reqf(Rval{
+                       Url:        "http://" + addr + "/to",
+                       Timeout:    1000,
+                       Async:      true,
+                       SaveToChan: c,
+               })
+               for {
+                       buf := <-c
+                       if len(buf) == 0 {
+                               break
+                       }
+               }
+               if !IsTimeout(r.Wait()) {
+                       t.Error("async cancel fail")
+               }
+       }
        {
                c := make(chan []byte)
                r.Reqf(Rval{
@@ -86,15 +123,6 @@ func Test_req(t *testing.T) {
                        t.Error("chan fail")
                }
        }
-       {
-               e := r.Reqf(Rval{
-                       Url:     "http://" + addr + "/to",
-                       Timeout: 1000,
-               })
-               if !IsTimeout(e) {
-                       t.Error("Timeout fail")
-               }
-       }
        {
                timer := time.NewTimer(time.Second)
                go func() {