]> 127.0.0.1 Git - part/.git/commitdiff
add
authorqydysky <qydysky@foxmail.com>
Tue, 16 May 2023 20:01:17 +0000 (04:01 +0800)
committerqydysky <qydysky@foxmail.com>
Tue, 16 May 2023 20:01:17 +0000 (04:01 +0800)
sync/RWMutex.go

index 712cc88f156d9e3b8cc111e0e5428f20f4332923..0b71740b68e26ffbf49c25894b7c1f68404fe764 100644 (file)
@@ -13,10 +13,9 @@ const (
 )
 
 type RWMutex struct {
-       rlc      atomic.Int32
-       cul      atomic.Int32
-       oll      atomic.Int32
-       wantRead atomic.Bool
+       rlc       atomic.Int32
+       wantRead  atomic.Bool
+       wantWrite atomic.Bool
 }
 
 // RLock() 必须在 lock期间操作的变量所定义的goroutime 中调用
@@ -34,7 +33,7 @@ func (m *RWMutex) RLock(to ...time.Duration) (lockf func() (unlockf func())) {
                                }
                        }
                        c := time.Now()
-                       for m.rlc.Load() < ulock {
+                       for m.rlc.Load() < ulock || m.wantWrite.Load() {
                                if time.Since(c) > to[0] {
                                        panic(fmt.Sprintf("timeout to wait lock, rlc:%d", m.rlc.Load()))
                                }
@@ -54,8 +53,7 @@ func (m *RWMutex) RLock(to ...time.Duration) (lockf func() (unlockf func())) {
                                }
                        }()
                } else {
-                       for m.rlc.Load() < ulock {
-                               time.Sleep(time.Millisecond)
+                       for m.rlc.Load() < ulock || m.wantWrite.Load() {
                                runtime.Gosched()
                        }
                }
@@ -73,7 +71,7 @@ func (m *RWMutex) RLock(to ...time.Duration) (lockf func() (unlockf func())) {
 
 // Lock() 必须在 lock期间操作的变量所定义的goroutime 中调用
 func (m *RWMutex) Lock(to ...time.Duration) (lockf func() (unlockf func())) {
-       lockid := m.cul.Add(1)
+       m.wantWrite.Store(true)
        return func() (unlock func()) {
                var callC atomic.Bool
                if len(to) > 0 {
@@ -86,21 +84,12 @@ func (m *RWMutex) Lock(to ...time.Duration) (lockf func() (unlockf func())) {
                                }
                        }
                        c := time.Now()
-                       for m.rlc.Load() > ulock || m.wantRead.Load() {
+                       for m.rlc.Load() != ulock || m.wantRead.Load() {
                                if time.Since(c) > to[0] {
                                        panic(fmt.Sprintf("timeout to wait rlock, rlc:%d", m.rlc.Load()))
                                }
                                runtime.Gosched()
                        }
-                       for lockid-1 != m.oll.Load() {
-                               if time.Since(c) > to[0] {
-                                       panic(fmt.Sprintf("timeout to wait lock, lockid:%d <> %d", lockid, m.oll.Load()))
-                               }
-                               runtime.Gosched()
-                       }
-                       if !m.rlc.CompareAndSwap(ulock, lock) {
-                               panic(fmt.Sprintf("csa error, rlc:%d", m.rlc.Load()))
-                       }
                        c = time.Now()
                        go func() {
                                for !callC.Load() {
@@ -115,26 +104,18 @@ func (m *RWMutex) Lock(to ...time.Duration) (lockf func() (unlockf func())) {
                                }
                        }()
                } else {
-                       for m.rlc.Load() > ulock || m.wantRead.Load() {
-                               time.Sleep(time.Millisecond)
-                               runtime.Gosched()
-                       }
-                       for lockid-1 != m.oll.Load() {
-                               time.Sleep(time.Millisecond)
+                       for m.rlc.Load() != ulock || m.wantRead.Load() {
                                runtime.Gosched()
                        }
-                       if !m.rlc.CompareAndSwap(ulock, lock) {
-                               panic(fmt.Sprintf("csa error, rlc:%d", m.rlc.Load()))
-                       }
                }
+               m.rlc.Add(-1)
                return func() {
                        if !callC.CompareAndSwap(false, true) {
                                panic("had unlock")
                        }
-                       if !m.rlc.CompareAndSwap(lock, ulock) {
-                               panic(fmt.Sprintf("csa error, rlc:%d", m.rlc.Load()))
+                       if m.rlc.Add(1) == ulock {
+                               m.wantWrite.Store(false)
                        }
-                       m.oll.Store(lockid)
                }
        }
 }