]> 127.0.0.1 Git - part/.git/commitdiff
1 (#15) v0.28.20250125173504
authorqydysky <qydysky@foxmail.com>
Sat, 25 Jan 2025 17:29:25 +0000 (01:29 +0800)
committerGitHub <noreply@github.com>
Sat, 25 Jan 2025 17:29:25 +0000 (01:29 +0800)
io/io.go
io/io_test.go
web/Web.go

index 8b808638fa6b44e06be37f692fbb10f5d39909cb..153d6a192cac9efb507435a6b02afb4853f51ea3 100644 (file)
--- a/io/io.go
+++ b/io/io.go
@@ -577,3 +577,41 @@ func ReadAll(r io.Reader, b []byte) ([]byte, error) {
                }
        }
 }
+
+type CacheWriter struct {
+       ctx             context.Context
+       cancelCauseFunc context.CancelCauseFunc
+       w               io.Writer
+       pushLock        atomic.Bool
+       pushBuf         []byte
+}
+
+var ErrBusy = errors.New(`ErrBusy`)
+
+func NewCacheWriter(ws io.Writer, ctx ...context.Context) *CacheWriter {
+       t := CacheWriter{w: ws}
+       ctx = append(ctx, context.Background())
+       t.ctx, t.cancelCauseFunc = context.WithCancelCause(ctx[0])
+       return &t
+}
+
+func (t *CacheWriter) Write(b []byte) (int, error) {
+       select {
+       case <-t.ctx.Done():
+               return 0, t.ctx.Err()
+       default:
+       }
+       if !t.pushLock.CompareAndSwap(false, true) {
+               return 0, ErrBusy
+       }
+       t.pushBuf = append(t.pushBuf[:0], b...)
+       go func() {
+               defer t.pushLock.Store(false)
+               if n, err := t.w.Write(t.pushBuf); err != nil || n == 0 {
+                       if !errors.Is(err, ErrBusy) {
+                               t.cancelCauseFunc(err)
+                       }
+               }
+       }()
+       return len(t.pushBuf), t.ctx.Err()
+}
index e3c02e1fee155b92e9447de06f13a88376dba05a..8dba4e39a0c9af56b4d47e6e6e4399f0c9b32a56 100644 (file)
@@ -2,8 +2,10 @@ package part
 
 import (
        "bytes"
+       "errors"
        "io"
        "testing"
+       "time"
 )
 
 func Test_CopyIO(t *testing.T) {
@@ -96,3 +98,23 @@ func Benchmark_readall1(b *testing.B) {
                r.Reset(data)
        }
 }
+
+func Test_CacheWrite(t *testing.T) {
+       r, w := io.Pipe()
+       rc, _ := RW2Chan(r, nil)
+       go func() {
+               time.Sleep(time.Millisecond * 500)
+               b := <-rc
+               if !bytes.Equal(b, []byte("123")) {
+                       t.Fatal()
+               }
+       }()
+       writer := NewCacheWriter(w)
+       if n, err := writer.Write([]byte("123")); n != 3 || err != nil {
+               t.Fatal()
+       }
+       if _, err := writer.Write([]byte("123")); !errors.Is(err, ErrBusy) {
+               t.Fatal()
+       }
+       time.Sleep(time.Second)
+}
index f261ca874621bd28141324768ac7ad9bc895670f..8493a0bb4c7a098c6c05135d4076c36dd179ded2 100644 (file)
@@ -14,6 +14,7 @@ import (
 
        "github.com/dustin/go-humanize"
        "github.com/google/uuid"
+       pio "github.com/qydysky/part/io"
        psync "github.com/qydysky/part/sync"
        sys "github.com/qydysky/part/sys"
 )
@@ -570,6 +571,29 @@ func (t withflush) WriteHeader(i int) {
        }
 }
 
+type withCache struct {
+       cw  *pio.CacheWriter
+       raw http.ResponseWriter
+}
+
+func (t withCache) Header() http.Header {
+       if t.raw != nil {
+               return t.raw.Header()
+       }
+       return make(http.Header)
+}
+func (t withCache) Write(b []byte) (i int, e error) {
+       if t.cw != nil {
+               return t.cw.Write(b)
+       }
+       return t.raw.Write(b)
+}
+func (t withCache) WriteHeader(i int) {
+       if t.raw != nil {
+               t.raw.WriteHeader(i)
+       }
+}
+
 type Exprier struct {
        max int
        m   psync.Map
@@ -687,6 +711,17 @@ func WithFlush(w http.ResponseWriter) http.ResponseWriter {
        return withflush{w}
 }
 
+// IsCacheBusy
+func WithCache(w http.ResponseWriter) http.ResponseWriter {
+       t := withCache{raw: w}
+       t.cw = pio.NewCacheWriter(w)
+       return t
+}
+
+func IsCacheBusy(e error) bool {
+       return errors.Is(e, pio.ErrBusy)
+}
+
 func WithStatusCode(w http.ResponseWriter, code int) {
        w.WriteHeader(code)
        _, _ = w.Write([]byte(http.StatusText(code)))