]> 127.0.0.1 Git - part/.git/commitdiff
1
authorqydysky <qydysky@foxmail.com>
Sat, 13 Jan 2024 22:01:47 +0000 (06:01 +0800)
committerqydysky <qydysky@foxmail.com>
Sat, 13 Jan 2024 22:01:47 +0000 (06:01 +0800)
sync/RWMutex.go
sync/RWMutex_test.go

index b115e020d0bd72cf453c4820d854ee6549e5b3ea..a0f0836d2225e87c0cb66b2e24b46e69fa02fd97 100644 (file)
@@ -1,6 +1,7 @@
 package part
 
 import (
+       "errors"
        "fmt"
        "runtime"
        "strings"
@@ -14,120 +15,120 @@ const (
        rlock int32 = 1
 )
 
+var (
+       ErrTimeoutToLock  = errors.New("ErrTimeoutToLock")
+       ErrTimeoutToULock = errors.New("ErrTimeoutToULock")
+)
+
 type RWMutex struct {
-       rlc  atomic.Int32
-       read atomic.Int32
+       rlc       atomic.Int32
+       PanicFunc func(any)
 }
 
-func parse(i int32) string {
-       switch i {
-       case -1:
-               return "lock"
-       case 0:
-               return "ulock"
-       case 1:
-               return "rlock"
-       }
-       return "unknow"
-}
+// func parse(i int32) string {
+//     switch i {
+//     case -2:
+//             return "lock"
+//     case -1:
+//             return "ulock"
+//     }
+//     return "rlock"
+// }
 
-// i == oldt -> i = t -> pass
-//
-// otherwish block until i == oldt
-func cas(i *atomic.Int32, oldt, t int32, to ...time.Duration) error {
-       c := time.Now()
-       for !i.CompareAndSwap(oldt, t) {
-               if len(to) != 0 && time.Since(c) > to[0] {
-                       return fmt.Errorf("timeout to set %s => %s while is %s", parse(oldt), parse(t), parse(i.Load()))
-               }
-               runtime.Gosched()
-       }
-       return nil
-}
+// // i == oldt -> i = t -> pass
+// //
+// // otherwish block until i == oldt
+// func cas(i *atomic.Int32, oldt, t int32, to ...time.Duration) error {
+//     c := time.Now()
+//     for !i.CompareAndSwap(oldt, t) {
+//             if len(to) != 0 && time.Since(c) > to[0] {
+//                     return fmt.Errorf("timeout to set %s => %s while is %s", parse(oldt), parse(t), parse(i.Load()))
+//             }
+//             runtime.Gosched()
+//     }
+//     return nil
+// }
 
-// i == t -> pass
-//
-// i == oldt -> i = t -> pass
-//
-// otherwish block until i == oldt
-func lcas(i *atomic.Int32, oldt, t int32, to ...time.Duration) error {
-       c := time.Now()
-       for i.Load() != t && !i.CompareAndSwap(oldt, t) {
-               if len(to) != 0 && time.Since(c) > to[0] {
-                       return fmt.Errorf("timeout to set %s => %s while is %s", parse(oldt), parse(t), parse(i.Load()))
-               }
-               runtime.Gosched()
+// // i == t -> pass
+// //
+// // i == oldt -> i = t -> pass
+// //
+// // otherwish block until i == oldt
+// func lcas(i *atomic.Int32, oldt, t int32, to ...time.Duration) error {
+//     c := time.Now()
+//     for i.Load() != t && !i.CompareAndSwap(oldt, t) {
+//             if len(to) != 0 && time.Since(c) > to[0] {
+//                     return fmt.Errorf("timeout to set %s => %s while is %s", parse(oldt), parse(t), parse(i.Load()))
+//             }
+//             runtime.Gosched()
+//     }
+//     return nil
+// }
+
+func (m *RWMutex) panicFunc(s any) {
+       if m.PanicFunc != nil {
+               m.PanicFunc(s)
+       } else {
+               panic(s)
        }
-       return nil
 }
 
 // call inTimeCall() in time or panic(callTree)
-func tof(to time.Duration) (inTimeCall func() (called bool)) {
+func (m *RWMutex) tof(to time.Duration, e error) (inTimeCall func() (called bool)) {
        callTree := getCall(2)
        return time.AfterFunc(to, func() {
-               panic("Locking timeout!\n" + callTree)
+               m.panicFunc(errors.Join(e, errors.New(callTree)))
        }).Stop
 }
 
-// to[0]: wait lock timeout to[1]: run lock timeout
+// to[0]: wait lock timeout
+//
+// to[1]: wait ulock timeout
 //
 // 不要在Rlock内设置变量,有DATA RACE风险
 func (m *RWMutex) RLock(to ...time.Duration) (unlockf func(beforeUlock ...func())) {
-       if m.read.Add(1) == 1 {
-               if e := cas(&m.rlc, ulock, rlock, to...); e != nil {
-                       panic(e)
-               }
-       } else {
-               if e := lcas(&m.rlc, ulock, rlock, to...); e != nil {
-                       panic(e)
-               }
+       if len(to) > 0 {
+               defer m.tof(to[0], ErrTimeoutToLock)()
        }
-       var callC atomic.Bool
-       var done func() (called bool)
-       if len(to) > 1 {
-               done = tof(to[1])
+
+       for m.rlc.Load() < rlock && !m.rlc.CompareAndSwap(ulock, rlock) {
+               runtime.Gosched()
        }
+
+       m.rlc.Add(1)
+
        return func(beforeUlock ...func()) {
-               if !callC.CompareAndSwap(false, true) {
-                       panic("had unlock")
+               if len(to) > 1 {
+                       defer m.tof(to[1], ErrTimeoutToULock)()
                }
-               if done != nil {
-                       defer done()
+               for i := 0; i < len(beforeUlock); i++ {
+                       beforeUlock[i]()
                }
-               if m.read.Add(-1) == 0 {
-                       for i := 0; i < len(beforeUlock); i++ {
-                               beforeUlock[i]()
-                       }
-                       if e := cas(&m.rlc, rlock, ulock, to...); e != nil {
-                               panic(e)
-                       }
+               if m.rlc.Add(-1) == rlock {
+                       m.rlc.CompareAndSwap(rlock, ulock)
                }
        }
 }
 
-// to[0]: wait lock timeout to[1]: run lock timeout
+// to[0]: wait lock timeout
+//
+// to[1]: wait ulock timeout
 func (m *RWMutex) Lock(to ...time.Duration) (unlockf func(beforeUlock ...func())) {
-       if e := cas(&m.rlc, ulock, lock, to...); e != nil {
-               panic(e)
+       if len(to) > 0 {
+               defer m.tof(to[0], ErrTimeoutToLock)()
        }
-       var callC atomic.Bool
-       var done func() (called bool)
-       if len(to) > 1 {
-               done = tof(to[1])
+       for !m.rlc.CompareAndSwap(ulock, lock) {
+               runtime.Gosched()
        }
+
        return func(beforeUlock ...func()) {
-               if !callC.CompareAndSwap(false, true) {
-                       panic("had unlock")
-               }
-               if done != nil {
-                       defer done()
+               if len(to) > 1 {
+                       defer m.tof(to[1], ErrTimeoutToULock)()
                }
                for i := 0; i < len(beforeUlock); i++ {
                        beforeUlock[i]()
                }
-               if e := cas(&m.rlc, lock, ulock, to...); e != nil {
-                       panic(e)
-               }
+               m.rlc.Store(ulock)
        }
 }
 
@@ -136,7 +137,7 @@ func getCall(i int) (calls string) {
                if pc, file, line, ok := runtime.Caller(i); !ok || strings.HasPrefix(file, runtime.GOROOT()) {
                        break
                } else {
-                       calls += fmt.Sprintf("call by %s\n\t%s:%d\n", runtime.FuncForPC(pc).Name(), file, line)
+                       calls += fmt.Sprintf("\ncall by %s\n\t%s:%d", runtime.FuncForPC(pc).Name(), file, line)
                }
        }
        return
index 7f4f3ffbfbfbf0d604a546a9b041008dd4e7805d..91c530ab317a525f716c064159599ced0c06c2a9 100644 (file)
@@ -1,16 +1,15 @@
 package part
 
 import (
+       "errors"
+       "fmt"
        "testing"
        "time"
 )
 
-func check(l *RWMutex, r, read int32) {
-       if l.rlc.Load() != r {
-               panic("rlc")
-       }
-       if l.read.Load() != read {
-               panic("read")
+func check(l *RWMutex, r int32) {
+       if i := l.rlc.Load(); i != r {
+               panic(fmt.Errorf("%v %v", i, r))
        }
 }
 
@@ -29,57 +28,90 @@ func Test0(t *testing.T) {
 // ulock rlock rlock
 func Test1(t *testing.T) {
        var l RWMutex
-       check(&l, ulock, 0)
+       check(&l, 0)
        ul := l.RLock()
-       check(&l, rlock, 1)
+       check(&l, 2)
        ul1 := l.RLock()
-       check(&l, rlock, 2)
+       check(&l, 3)
        ul()
-       check(&l, rlock, 1)
+       check(&l, 2)
+       ul1()
+       check(&l, 0)
+}
+
+func Test4(t *testing.T) {
+       var l RWMutex
+       ul := l.RLock()
+       ul(func() {
+               ul1 := l.RLock()
+               ul1()
+       })
+}
+
+func Test5(t *testing.T) {
+       var l = RWMutex{PanicFunc: func(a any) {
+               if !errors.Is(a.(error), ErrTimeoutToULock) {
+                       t.Fatal(a)
+               }
+       }}
+       ul := l.RLock(time.Second, time.Second)
+       ul(func() {
+               time.Sleep(time.Second * 2)
+       })
+}
+
+func Test8(t *testing.T) {
+       var l = RWMutex{PanicFunc: func(a any) {
+               if !errors.Is(a.(error), ErrTimeoutToLock) {
+                       panic(a)
+               }
+       }}
+       ul := l.Lock()
+       go ul(func() { time.Sleep(time.Second) })
+       ul1 := l.RLock(time.Millisecond*500, time.Second)
        ul1()
-       check(&l, ulock, 0)
 }
 
 // ulock rlock lock
 func Test2(t *testing.T) {
        var l RWMutex
        ul := l.RLock()
-       check(&l, rlock, 1)
+       check(&l, 2)
        time.AfterFunc(time.Second, func() {
-               check(&l, rlock, 1)
+               check(&l, 2)
                ul()
        })
        c := time.Now()
        ul1 := l.Lock()
-       check(&l, lock, 0)
+       check(&l, -1)
        if time.Since(c) < time.Second {
                t.Fail()
        }
        ul1()
-       check(&l, ulock, 0)
+       check(&l, 0)
 }
 
 // ulock lock rlock
 func Test3(t *testing.T) {
        var l RWMutex
        ul := l.Lock()
-       check(&l, lock, 0)
+       check(&l, -1)
        time.AfterFunc(time.Second, func() {
-               check(&l, lock, 1)
+               check(&l, -1)
                ul()
        })
        c := time.Now()
        ul1 := l.RLock()
-       check(&l, rlock, 1)
+       check(&l, 2)
        if time.Since(c) < time.Second {
                t.Fail()
        }
        ul1()
-       check(&l, ulock, 0)
+       check(&l, 0)
 }
 
 func Test6(t *testing.T) {
-       var c = make(chan int, 2)
+       var c = make(chan int, 3)
        var l RWMutex
        ul := l.RLock()
        time.AfterFunc(time.Millisecond*500, func() {
@@ -92,10 +124,13 @@ func Test6(t *testing.T) {
                time.Sleep(time.Second)
        })
        c <- 0
-       if <-c != 0 {
+       if <-c != 1 {
                t.Fatal()
        }
-       if <-c != 1 {
+       if <-c != 2 {
+               t.Fatal()
+       }
+       if <-c != 0 {
                t.Fatal()
        }
 }
@@ -111,6 +146,9 @@ func Test7(t *testing.T) {
        ul1(func() {
                c <- 1
        })
+       if <-c != 0 {
+               t.Fatal()
+       }
        if <-c != 1 {
                t.Fatal()
        }
@@ -127,16 +165,16 @@ func Panic_Test8(t *testing.T) {
 // ulock rlock rlock
 func Panic_Test4(t *testing.T) {
        var l RWMutex
-       check(&l, ulock, 0)
+       check(&l, 0)
        ul := l.RLock(time.Second, time.Second)
-       check(&l, rlock, 1)
+       check(&l, 1)
        ul1 := l.RLock(time.Second, time.Second)
-       check(&l, rlock, 2)
+       check(&l, 2)
        time.Sleep(time.Millisecond * 1500)
        ul()
-       check(&l, rlock, 1)
+       check(&l, 1)
        ul1()
-       check(&l, ulock, 0)
+       check(&l, 0)
        time.Sleep(time.Second * 3)
 }
 
@@ -144,28 +182,36 @@ func Panic_Test4(t *testing.T) {
 func Panic_Test5(t *testing.T) {
        var l RWMutex
        ul := l.RLock()
-       check(&l, rlock, 1)
+       check(&l, 1)
        time.AfterFunc(time.Millisecond*1500, func() {
-               check(&l, rlock, 1)
+               check(&l, 1)
                ul()
        })
        c := time.Now()
        ul1 := l.Lock(time.Second)
-       check(&l, lock, 0)
+       check(&l, 0)
        if time.Since(c) < time.Second {
                t.Fail()
        }
        ul1()
-       check(&l, ulock, 0)
+       check(&l, 0)
 }
 
+/*
+goos: linux
+goarch: amd64
+pkg: github.com/qydysky/part/sync
+cpu: Intel(R) Celeron(R) J4125 CPU @ 2.00GHz
+BenchmarkRlock
+BenchmarkRlock-4
+
+       1000000              1069 ns/op              24 B/op          1 allocs/op
+
+PASS
+*/
 func BenchmarkRlock(b *testing.B) {
        var lock1 RWMutex
-       var a bool
        for i := 0; i < b.N; i++ {
-               ul := lock1.RLock()
-               a = true
-               ul()
+               lock1.Lock()()
        }
-       println(a)
 }