diff --git a/device/alignment_test.go b/device/alignment_test.go index 5587cbe..46baeb1 100644 --- a/device/alignment_test.go +++ b/device/alignment_test.go @@ -42,7 +42,6 @@ func TestPeerAlignment(t *testing.T) { checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) } - // TestDeviceAlignment checks that atomically-accessed fields are // aligned to 64-bit boundaries, as required by the atomic package. // diff --git a/device/device.go b/device/device.go index bac361e..5f36036 100644 --- a/device/device.go +++ b/device/device.go @@ -67,12 +67,9 @@ type Device struct { } pool struct { - messageBufferPool *sync.Pool - messageBufferReuseChan chan *[MaxMessageSize]byte - inboundElementPool *sync.Pool - inboundElementReuseChan chan *QueueInboundElement - outboundElementPool *sync.Pool - outboundElementReuseChan chan *QueueOutboundElement + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { diff --git a/device/pools.go b/device/pools.go index eb6d6be..f1d1fa0 100644 --- a/device/pools.go +++ b/device/pools.go @@ -5,87 +5,80 @@ package device -import "sync" +import ( + "sync" + "sync/atomic" +) + +type WaitPool struct { + pool sync.Pool + cond sync.Cond + lock sync.Mutex + count uint32 + max uint32 +} + +func NewWaitPool(max uint32, new func() interface{}) *WaitPool { + p := &WaitPool{pool: sync.Pool{New: new}, max: max} + p.cond = sync.Cond{L: &p.lock} + return p +} + +func (p *WaitPool) Get() interface{} { + if p.max != 0 { + p.lock.Lock() + for atomic.LoadUint32(&p.count) >= p.max { + p.cond.Wait() + } + atomic.AddUint32(&p.count, 1) + p.lock.Unlock() + } + return p.pool.Get() +} + +func (p *WaitPool) Put(x interface{}) { + p.pool.Put(x) + if p.max == 0 { + return + } + atomic.AddUint32(&p.count, ^uint32(0)) + p.cond.Signal() +} func (device *Device) PopulatePools() { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - device.pool.inboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueInboundElement) - }, - } - device.pool.outboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueOutboundElement) - }, - } - } else { - device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i++ { - device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) - } - device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i++ { - device.pool.inboundElementReuseChan <- new(QueueInboundElement) - } - device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i++ { - device.pool.outboundElementReuseChan <- new(QueueOutboundElement) - } - } + device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { + return new([MaxMessageSize]byte) + }) + device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { + return new(QueueInboundElement) + }) + device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { + return new(QueueOutboundElement) + }) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if PreallocatedBuffersPerPool == 0 { - return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) - } else { - return <-device.pool.messageBufferReuseChan - } + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) } func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool.Put(msg) - } else { - device.pool.messageBufferReuseChan <- msg - } + device.pool.messageBuffers.Put(msg) } func (device *Device) GetInboundElement() *QueueInboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.inboundElementPool.Get().(*QueueInboundElement) - } else { - return <-device.pool.inboundElementReuseChan - } + return device.pool.inboundElements.Get().(*QueueInboundElement) } func (device *Device) PutInboundElement(elem *QueueInboundElement) { elem.clearPointers() - if PreallocatedBuffersPerPool == 0 { - device.pool.inboundElementPool.Put(elem) - } else { - device.pool.inboundElementReuseChan <- elem - } + device.pool.inboundElements.Put(elem) } func (device *Device) GetOutboundElement() *QueueOutboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.outboundElementPool.Get().(*QueueOutboundElement) - } else { - return <-device.pool.outboundElementReuseChan - } + return device.pool.outboundElements.Get().(*QueueOutboundElement) } func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { elem.clearPointers() - if PreallocatedBuffersPerPool == 0 { - device.pool.outboundElementPool.Put(elem) - } else { - device.pool.outboundElementReuseChan <- elem - } + device.pool.outboundElements.Put(elem) } diff --git a/device/pools_test.go b/device/pools_test.go new file mode 100644 index 0000000..e6cbac5 --- /dev/null +++ b/device/pools_test.go @@ -0,0 +1,60 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestWaitPool(t *testing.T) { + var wg sync.WaitGroup + trials := int32(100000) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + t.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) + wg.Add(workers) + max := uint32(0) + updateMax := func() { + count := atomic.LoadUint32(&p.count) + if count > p.max { + t.Errorf("count (%d) > max (%d)", count, p.max) + } + for { + old := atomic.LoadUint32(&max) + if count <= old { + break + } + if atomic.CompareAndSwapUint32(&max, old, count) { + break + } + } + } + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for atomic.AddInt32(&trials, -1) > 0 { + updateMax() + x := p.Get() + updateMax() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + updateMax() + p.Put(x) + updateMax() + } + }() + } + wg.Wait() + if max != p.max { + t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) + } +}