diff --git a/util/locks/prioritymutex.go b/util/locks/prioritymutex.go index 89c7a93ef..b9ae7a015 100644 --- a/util/locks/prioritymutex.go +++ b/util/locks/prioritymutex.go @@ -27,13 +27,13 @@ import ( // the read lock. type PriorityMutex struct { dataMutex sync.RWMutex - highPriorityWaiting sync.WaitGroup + highPriorityWaiting *waitGroup lowPriorityMutex sync.Mutex } func NewPriorityMutex() *PriorityMutex { lock := PriorityMutex{ - highPriorityWaiting: sync.WaitGroup{}, + highPriorityWaiting: newWaitGroup(), } return &lock } @@ -41,7 +41,7 @@ func NewPriorityMutex() *PriorityMutex { // LowPriorityWriteLock acquires a low-priority write lock. func (mtx *PriorityMutex) LowPriorityWriteLock() { mtx.lowPriorityMutex.Lock() - mtx.highPriorityWaiting.Wait() + mtx.highPriorityWaiting.wait() mtx.dataMutex.Lock() } @@ -53,26 +53,26 @@ func (mtx *PriorityMutex) LowPriorityWriteUnlock() { // HighPriorityWriteLock acquires a high-priority write lock. func (mtx *PriorityMutex) HighPriorityWriteLock() { - mtx.highPriorityWaiting.Add(1) + mtx.highPriorityWaiting.add() mtx.dataMutex.Lock() } // HighPriorityWriteUnlock unlocks the high-priority write lock func (mtx *PriorityMutex) HighPriorityWriteUnlock() { mtx.dataMutex.Unlock() - mtx.highPriorityWaiting.Done() + mtx.highPriorityWaiting.done() } // HighPriorityReadLock acquires a high-priority read // lock. func (mtx *PriorityMutex) HighPriorityReadLock() { - mtx.highPriorityWaiting.Add(1) + mtx.highPriorityWaiting.add() mtx.dataMutex.RLock() } // HighPriorityWriteUnlock unlocks the high-priority read // lock func (mtx *PriorityMutex) HighPriorityReadUnlock() { - mtx.highPriorityWaiting.Done() + mtx.highPriorityWaiting.done() mtx.dataMutex.RUnlock() } diff --git a/util/locks/waitgroup.go b/util/locks/waitgroup.go new file mode 100644 index 000000000..3356ac666 --- /dev/null +++ b/util/locks/waitgroup.go @@ -0,0 +1,39 @@ +package locks + +import ( + "sync" + "sync/atomic" +) + +type waitGroup struct { + counter int64 + waitCond *sync.Cond +} + +func newWaitGroup() *waitGroup { + return &waitGroup{ + waitCond: sync.NewCond(&sync.Mutex{}), + } +} + +func (wg *waitGroup) add() { + atomic.AddInt64(&wg.counter, 1) +} + +func (wg *waitGroup) done() { + counter := atomic.AddInt64(&wg.counter, -1) + if counter < 0 { + panic("negative values for wg.counter are not allowed. This was likely caused by calling done() before add()") + } + if atomic.LoadInt64(&wg.counter) == 0 { + wg.waitCond.Signal() + } +} + +func (wg *waitGroup) wait() { + wg.waitCond.L.Lock() + defer wg.waitCond.L.Unlock() + for atomic.LoadInt64(&wg.counter) != 0 { + wg.waitCond.Wait() + } +}