From: qydysky <32743305+qydysky@users.noreply.github.com> Date: Wed, 1 Mar 2023 18:32:56 +0000 (+0800) Subject: Fix reqf X-Git-Tag: v0.23.1 X-Git-Url: http://127.0.0.1:8081/?a=commitdiff_plain;h=b514001b1944a6b3808e4fda938c156093a40c56;p=part%2F.git Fix reqf --- diff --git a/reqf/Reqf.go b/reqf/Reqf.go index 639db16..44f8094 100644 --- a/reqf/Reqf.go +++ b/reqf/Reqf.go @@ -11,6 +11,7 @@ import ( "os" "strconv" "strings" + "sync" "time" flate "compress/flate" @@ -48,61 +49,59 @@ type Req struct { Response *http.Response UsedTime time.Duration - cancel *signal.Signal - running *signal.Signal - responBuf *bytes.Buffer + cancelF func() + cancel *signal.Signal + running *signal.Signal + responFile *os.File - asyncErr error + err error + + l sync.Mutex } func New() *Req { return new(Req) } -// func main(){ -// var _ReqfVal = ReqfVal{ -// Url:url, -// Proxy:proxy, -// Timeout:10, -// Retry:2, -// } -// Reqf(_ReqfVal) -// } - func (t *Req) Reqf(val Rval) error { - - if val.SaveToChan != nil && len(val.SaveToChan) == 1 && !val.Async { - panic("must make sure chan size larger then 1 or use Async true") - } - if val.SaveToPipeWriter != nil && !val.Async { - panic("SaveToPipeWriter must use Async true") - } + t.l.Lock() t.Respon = []byte{} t.Response = nil t.UsedTime = 0 + t.cancelF = nil t.cancel = signal.Init() t.running = signal.Init() + t.responFile = nil + t.err = nil - var returnErr error - - _val := val - - for SleepTime, Retry := _val.SleepTime, _val.Retry; Retry >= 0; Retry -= 1 { - returnErr = t.Reqf_1(_val) + go func() { select { - case <-t.cancel.WaitC(): //cancel - return returnErr - default: - if returnErr == nil { - return nil + case <-t.cancel.Chan: + if t.cancelF != nil { + t.cancelF() } + case <-t.running.Chan: } - time.Sleep(time.Duration(SleepTime) * time.Millisecond) - } + }() + go func() { + beginTime := time.Now() + _val := val + + for SleepTime, Retry := _val.SleepTime, _val.Retry; Retry >= 0; Retry -= 1 { + t.err = t.Reqf_1(_val) + if t.err == nil || IsCancel(t.err) { + break + } + time.Sleep(time.Duration(SleepTime) * time.Millisecond) + } + + t.UsedTime = time.Since(beginTime) + t.running.Done() + }() - if !val.Async || returnErr != nil { - t.asyncErr = returnErr + if !val.Async { + t.Wait() if val.SaveToChan != nil { close(val.SaveToChan) } @@ -112,13 +111,25 @@ func (t *Req) Reqf(val Rval) error { if val.SaveToPipeWriter != nil { val.SaveToPipeWriter.Close() } - if t.responBuf != nil { - t.Respon = t.responBuf.Bytes() - } - t.running.Done() t.cancel.Done() + t.l.Unlock() + } else { + go func() { + t.Wait() + if val.SaveToChan != nil { + close(val.SaveToChan) + } + if t.responFile != nil { + t.responFile.Close() + } + if val.SaveToPipeWriter != nil { + val.SaveToPipeWriter.Close() + } + t.cancel.Done() + t.l.Unlock() + }() } - return returnErr + return t.err } func (t *Req) Reqf_1(val Rval) (err error) { @@ -126,8 +137,6 @@ func (t *Req) Reqf_1(val Rval) (err error) { Header map[string]string = val.Header ) - var beginTime time.Time = time.Now() - var client http.Client if Header == nil { @@ -166,14 +175,7 @@ func (t *Req) Reqf_1(val Rval) (err error) { if val.Timeout > 0 { cx, cancel = context.WithTimeout(cx, time.Duration(val.Timeout)*time.Millisecond) } - - go func() { - select { - case <-t.cancel.WaitC(): - cancel() - case <-t.running.WaitC(): - } - }() + t.cancelF = cancel req, e := http.NewRequest(Method, val.Url, body) if e != nil { @@ -205,19 +207,23 @@ func (t *Req) Reqf_1(val Rval) (err error) { req.Header.Set(k, v) } - resp, err := client.Do(req) - if v, ok := Header["Connection"]; ok && strings.ToLower(v) != "keep-alive" { - defer client.CloseIdleConnections() + if !t.cancel.Islive() { + err = context.Canceled + return } - if err != nil { - return err + resp, e := client.Do(req) + + if e != nil { + err = e + return + } + + if v, ok := Header["Connection"]; ok && strings.ToLower(v) != "keep-alive" { + defer client.CloseIdleConnections() } t.Response = resp - defer func() { - t.UsedTime = time.Since(beginTime) - }() if val.JustResponseCode { return @@ -227,12 +233,14 @@ func (t *Req) Reqf_1(val Rval) (err error) { err = errors.New(strconv.Itoa(resp.StatusCode)) } + var responBuf *bytes.Buffer var ws []io.Writer if val.SaveToPath != "" { - t.responFile, err = os.Create(val.SaveToPath) + t.responFile, e = os.Create(val.SaveToPath) if err != nil { t.responFile.Close() - return err + err = e + return } ws = append(ws, t.responFile) } @@ -240,12 +248,10 @@ func (t *Req) Reqf_1(val Rval) (err error) { ws = append(ws, val.SaveToPipeWriter) } if !val.NoResponse { - if t.responBuf == nil { - t.responBuf = new(bytes.Buffer) - } else { - t.responBuf.Reset() + if responBuf == nil { + responBuf = new(bytes.Buffer) } - ws = append(ws, t.responBuf) + ws = append(ws, responBuf) } w := io.MultiWriter(ws...) @@ -266,59 +272,40 @@ func (t *Req) Reqf_1(val Rval) (err error) { resReader = resp.Body } - go func() { - buf := make([]byte, 512) - - for { - if n, e := resReader.Read(buf); n != 0 { - w.Write(buf[:n]) - select { - case val.SaveToChan <- buf[:n]: - default: - } - } else if e != nil { - if !errors.Is(e, io.EOF) { - err = e - } - break - } + buf := make([]byte, 512) - if !t.cancel.Islive() { - err = context.Canceled - break + for { + if n, e := resReader.Read(buf); n != 0 { + w.Write(buf[:n]) + select { + case val.SaveToChan <- buf[:n]: + default: } + } else if e != nil { + if !errors.Is(e, io.EOF) { + err = e + } + break } - if val.Async { - t.asyncErr = err + if !t.cancel.Islive() { + err = context.Canceled + break } - resp.Body.Close() - if val.SaveToChan != nil { - close(val.SaveToChan) - } - if t.responFile != nil { - t.responFile.Close() - } - if val.SaveToPipeWriter != nil { - val.SaveToPipeWriter.Close() - } - if t.responBuf != nil { - t.Respon = t.responBuf.Bytes() - } - t.running.Done() - }() - if !val.Async { - t.Wait() } - // if _, e := io.Copy(w, resp.Body); e != nil { - // err = e - // } + + resp.Body.Close() + + if responBuf != nil { + t.Respon = responBuf.Bytes() + } + return } func (t *Req) Wait() error { t.running.Wait() - return t.asyncErr + return t.err } func (t *Req) Cancel() { t.Close() } diff --git a/reqf/Reqf_test.go b/reqf/Reqf_test.go index 7113b8b..6c5fe3b 100644 --- a/reqf/Reqf_test.go +++ b/reqf/Reqf_test.go @@ -12,8 +12,9 @@ import ( web "github.com/qydysky/part/web" ) -func Test_req(t *testing.T) { - addr := "127.0.0.1:10001" +var addr = "127.0.0.1:10001" + +func init() { s := web.New(&http.Server{ Addr: addr, WriteTimeout: time.Second * time.Duration(10), @@ -47,13 +48,27 @@ func Test_req(t *testing.T) { s.Server.Shutdown(context.Background()) }, }) +} + +func Test_req7(t *testing.T) { + r := New() + r.Reqf(Rval{ + Url: "http://" + addr + "/to", + Async: true, + }) + r.Cancel() + if !IsCancel(r.Wait()) { + t.Error("async Cancel fail") + } +} +func Test_req(t *testing.T) { r := New() r.Reqf(Rval{ Url: "http://" + addr + "/br", }) if !bytes.Equal(r.Respon, []byte("abc强强强强")) { - t.Error("br fail") + t.Error("br fail", r.Respon) } r.Reqf(Rval{ Url: "http://" + addr + "/gzip", @@ -67,6 +82,10 @@ func Test_req(t *testing.T) { if !bytes.Equal(r.Respon, []byte("abc强强强强")) { t.Error("flate fail") } +} + +func Test_req2(t *testing.T) { + r := New() { e := r.Reqf(Rval{ Url: "http://" + addr + "/to", @@ -76,16 +95,24 @@ func Test_req(t *testing.T) { t.Error("Timeout fail") } } +} + +func Test_req4(t *testing.T) { + r := New() { r.Reqf(Rval{ Url: "http://" + addr + "/to", Timeout: 1000, Async: true, }) - if !IsTimeout(r.Wait()) { - t.Error("Async Timeout fail") + if e := r.Wait(); !IsTimeout(e) { + t.Error("Async Timeout fail", e) } } +} + +func Test_req5(t *testing.T) { + r := New() { c := make(chan []byte) r.Reqf(Rval{ @@ -101,9 +128,13 @@ func Test_req(t *testing.T) { } } if !IsTimeout(r.Wait()) { - t.Error("async cancel fail") + t.Error("async IsTimeout fail") } } +} + +func Test_req6(t *testing.T) { + r := New() { c := make(chan []byte) r.Reqf(Rval{ @@ -136,6 +167,10 @@ func Test_req(t *testing.T) { t.Error("Cancel fail") } } +} + +func Test_req3(t *testing.T) { + r := New() { rc, wc := io.Pipe() c := make(chan struct{}) @@ -180,6 +215,25 @@ func Test_req(t *testing.T) { } close(c) }() + r.Reqf(Rval{ + Url: "http://" + addr + "/flate", + SaveToPipeWriter: wc, + }) + <-c + } + { + rc, wc := io.Pipe() + c := make(chan struct{}) + go func() { + var buf []byte = make([]byte, 1<<16) + n, _ := rc.Read(buf) + d := buf[:n] + // d, _ := io.ReadAll(rc) + if !bytes.Equal(d, []byte("abc强强强强")) { + t.Error("flate fail") + } + close(c) + }() r.Reqf(Rval{ Url: "http://" + addr + "/flate", SaveToPipeWriter: wc, @@ -192,12 +246,9 @@ func Test_req(t *testing.T) { Url: "http://" + addr + "/flate", Async: true, }) - if len(r.Respon) != 0 { - t.Error("async fail") - } r.Wait() if !bytes.Equal(r.Respon, []byte("abc强强强强")) { - t.Error("async fail") + t.Error("async fail", r.Respon) } } { @@ -211,12 +262,11 @@ func Test_req(t *testing.T) { if len(r.Respon) != 0 { t.Error("io async fail") } - d, _ := io.ReadAll(rc) + var buf []byte = make([]byte, 1<<16) + n, _ := rc.Read(buf) + d := buf[:n] if !bytes.Equal(d, []byte("abc强强强强")) { - t.Error("io async fail") - } - if !bytes.Equal(r.Respon, []byte("abc强强强强")) { - t.Error("io async fail") + t.Error("io async fail", d) } } }