Remove the BatchQueue type #1117

Open

wants to merge 3 commits into main
120 changes: 37 additions & 83 deletions writer.go
@@ -124,6 +124,14 @@ type Writer struct {
// The default is to use a kafka default value of 1048576.
BatchBytes int64

// The maximum number of batches to buffer before blocking further message
// production. Setting the value too low may reduce throughput or cause the
// producer to block; setting it higher increases memory usage when the
// producer runs faster than the brokers can keep up.
//
// The default is 0, which disables blocking and buffers batches without limit.
MaxBufferedBatches int

// Time limit on how often incomplete message batches will be flushed to
// kafka.
//
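For illustration, a producer opting into bounded buffering would set the new field alongside the existing batching options. A minimal sketch, assuming the MaxBufferedBatches field from this change and placeholder broker/topic values:

package main

import (
	"context"
	"log"

	kafka "github.com/segmentio/kafka-go"
)

func main() {
	// Sketch only: MaxBufferedBatches is the field added by this change; the
	// broker address and topic are placeholders.
	w := &kafka.Writer{
		Addr:               kafka.TCP("localhost:9092"),
		Topic:              "events",
		BatchSize:          100,
		MaxBufferedBatches: 8, // block producers once 8 batches are queued per partition
	}
	defer w.Close()

	err := w.WriteMessages(context.Background(), kafka.Message{Value: []byte("hello")})
	if err != nil {
		log.Fatal(err)
	}
}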
@@ -915,74 +923,14 @@ func (w *Writer) chooseTopic(msg Message) (string, error) {
return w.Topic, nil
}

type batchQueue struct {
queue []*writeBatch

// Pointers are used here to make `go vet` happy, and avoid copying mutexes.
// It may be better to revert these to non-pointers and avoid the copies in
// a different way.
mutex *sync.Mutex
cond *sync.Cond

closed bool
}

func (b *batchQueue) Put(batch *writeBatch) bool {
b.cond.L.Lock()
defer b.cond.L.Unlock()
defer b.cond.Broadcast()

if b.closed {
return false
}
b.queue = append(b.queue, batch)
return true
}

func (b *batchQueue) Get() *writeBatch {
b.cond.L.Lock()
defer b.cond.L.Unlock()

for len(b.queue) == 0 && !b.closed {
b.cond.Wait()
}

if len(b.queue) == 0 {
return nil
}

batch := b.queue[0]
b.queue[0] = nil
b.queue = b.queue[1:]

return batch
}

func (b *batchQueue) Close() {
b.cond.L.Lock()
defer b.cond.L.Unlock()
defer b.cond.Broadcast()

b.closed = true
}

func newBatchQueue(initialSize int) batchQueue {
bq := batchQueue{
queue: make([]*writeBatch, 0, initialSize),
mutex: &sync.Mutex{},
cond: &sync.Cond{},
}

bq.cond.L = bq.mutex

return bq
}

// partitionWriter is a writer for a topic-partion pair. It maintains messaging order
// partitionWriter is a writer for a topic-partition pair. It maintains messaging order
// across batches of messages.
type partitionWriter struct {
meta topicPartition
queue batchQueue
queue chan *writeBatch
// blockOnQueue provides backpressure for writers which specified a limited
// MaxBufferedBatches. Otherwise, it simulates an unlimited buffer.
blockOnQueue bool

mutex sync.Mutex
currBatch *writeBatch
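As an aside on the removed comment about pointer fields: what `go vet` objects to is copying a struct that holds synchronization primitives by value. A standalone sketch of the failure mode the pointers were avoiding (illustrative names, not from this repository):

package main

import "sync"

// valueQueue keeps its lock and condition variable by value instead of by pointer.
type valueQueue struct {
	mu   sync.Mutex
	cond sync.Cond
	data []int
}

// newValueQueue returns the struct by value. The return copies mu and cond, which
// `go vet` (copylocks) is expected to flag, and the copy's cond.L still points at
// the original mutex, so the copies no longer lock the same thing.
func newValueQueue() valueQueue {
	q := valueQueue{}
	q.cond.L = &q.mu
	return q
}

func main() {
	q := newValueQueue()
	_ = q.data
}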
@@ -994,25 +942,21 @@ type partitionWriter struct {

func newPartitionWriter(w *Writer, key topicPartition) *partitionWriter {
writer := &partitionWriter{
meta: key,
queue: newBatchQueue(10),
w: w,
meta: key,
w: w,
}
if w.MaxBufferedBatches > 0 {
writer.queue = make(chan *writeBatch, w.MaxBufferedBatches)
writer.blockOnQueue = true
} else {
writer.queue = make(chan *writeBatch)
}
w.spawn(writer.writeBatches)
return writer
}

func (ptw *partitionWriter) writeBatches() {
for {
batch := ptw.queue.Get()

// The only time we can return nil is when the queue is closed
// and empty. If the queue is closed that means
// the Writer is closed so once we're here it's time to exit.
if batch == nil {
return
}

for batch := range ptw.queue {
ptw.writeBatch(batch)
}
}
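The deleted nil-check is covered by channel semantics: closing a channel still lets the receiver drain whatever was queued before the close, and the range loop then terminates on its own. A minimal illustration (names are placeholders, not from the PR):

package main

import "fmt"

func main() {
	queue := make(chan string, 2)
	queue <- "batch-1"
	queue <- "batch-2"
	close(queue)

	// Items sent before close are still delivered; the loop exits once the
	// channel is both closed and drained, which is what writeBatches relies on.
	for batch := range queue {
		fmt.Println("writing", batch)
	}
}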
@@ -1038,14 +982,14 @@ func (ptw *partitionWriter) writeMessages(msgs []Message, indexes []int32) map[*
}
if !batch.add(msgs[i], batchSize, batchBytes) {
batch.trigger()
ptw.queue.Put(batch)
ptw.queueBatch(batch)
ptw.currBatch = nil
goto assignMessage
}

if batch.full(batchSize, batchBytes) {
batch.trigger()
ptw.queue.Put(batch)
ptw.queueBatch(batch)
ptw.currBatch = nil
}

@@ -1056,6 +1000,16 @@ func (ptw *partitionWriter) writeMessages(msgs []Message, indexes []int32) map[*
return batches
}

func (ptw *partitionWriter) queueBatch(batch *writeBatch) {
if ptw.blockOnQueue {
ptw.queue <- batch
} else {
go func() {
ptw.queue <- batch
}()
}
}

// ptw.w can be accessed here because this is called with the lock ptw.mutex already held.
func (ptw *partitionWriter) newWriteBatch() *writeBatch {
batch := newWriteBatch(time.Now(), ptw.w.batchTimeout())
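queueBatch above is where the two modes diverge: with a bounded buffer the send blocks the producer once the channel is full, while the unbounded mode hands each send to its own goroutine so the producer never waits. A standalone sketch of that pattern (illustrative only, outside the Writer types):

package main

import (
	"fmt"
	"time"
)

// enqueue mirrors the shape of queueBatch: block the caller when the queue is
// bounded, otherwise park the send in its own goroutine to simulate an
// unlimited buffer.
func enqueue(queue chan int, bounded bool, v int) {
	if bounded {
		queue <- v // blocks once the buffer is full, applying backpressure
		return
	}
	go func() { queue <- v }() // caller never waits; pending sends sit in goroutines
}

func main() {
	bounded := make(chan int, 2) // like MaxBufferedBatches == 2
	enqueue(bounded, true, 1)
	enqueue(bounded, true, 2)
	// A third bounded enqueue would block here until a consumer drained one.

	unbounded := make(chan int) // unbuffered, like MaxBufferedBatches == 0
	for i := 0; i < 100; i++ {
		enqueue(unbounded, false, i) // returns immediately every time
	}

	// Consume a few items to show the receiver side (the range loop in writeBatches).
	go func() {
		for v := range unbounded {
			fmt.Println("consumed", v)
		}
	}()
	time.Sleep(50 * time.Millisecond)
}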
@@ -1078,7 +1032,7 @@ func (ptw *partitionWriter) awaitBatch(batch *writeBatch) {
// pw.currBatch != batch so we just move on.
// Otherwise, we detach the batch from the ptWriter and enqueue it for writing.
if ptw.currBatch == batch {
ptw.queue.Put(batch)
ptw.queueBatch(batch)
ptw.currBatch = nil
}
ptw.mutex.Unlock()
@@ -1182,12 +1136,12 @@ func (ptw *partitionWriter) close() {

if ptw.currBatch != nil {
batch := ptw.currBatch
ptw.queue.Put(batch)
ptw.queue <- batch
ptw.currBatch = nil
batch.trigger()
}

ptw.queue.Close()
close(ptw.queue)
}

type writeBatch struct {
82 changes: 0 additions & 82 deletions writer_test.go
@@ -7,94 +7,12 @@ import (
"io"
"math"
"strconv"
"sync"
"testing"
"time"

"github.com/segmentio/kafka-go/sasl/plain"
)

func TestBatchQueue(t *testing.T) {
tests := []struct {
scenario string
function func(*testing.T)
}{
{
scenario: "the remaining items in a queue can be gotten after closing",
function: testBatchQueueGetWorksAfterClose,
},
{
scenario: "putting into a closed queue fails",
function: testBatchQueuePutAfterCloseFails,
},
{
scenario: "putting into a queue awakes a goroutine in a get call",
function: testBatchQueuePutWakesSleepingGetter,
},
}

for _, test := range tests {
testFunc := test.function
t.Run(test.scenario, func(t *testing.T) {
t.Parallel()
testFunc(t)
})
}
}

func testBatchQueuePutWakesSleepingGetter(t *testing.T) {
bq := newBatchQueue(10)
var wg sync.WaitGroup
ready := make(chan struct{})
var batch *writeBatch
wg.Add(1)
go func() {
defer wg.Done()
close(ready)
batch = bq.Get()
}()
<-ready
bq.Put(newWriteBatch(time.Now(), time.Hour*100))
wg.Wait()
if batch == nil {
t.Fatal("got nil batch")
}
}

func testBatchQueuePutAfterCloseFails(t *testing.T) {
bq := newBatchQueue(10)
bq.Close()
if put := bq.Put(newWriteBatch(time.Now(), time.Hour*100)); put {
t.Fatal("put batch into closed queue")
}
}

func testBatchQueueGetWorksAfterClose(t *testing.T) {
bq := newBatchQueue(10)
enqueueBatches := []*writeBatch{
newWriteBatch(time.Now(), time.Hour*100),
newWriteBatch(time.Now(), time.Hour*100),
}

for _, batch := range enqueueBatches {
put := bq.Put(batch)
if !put {
t.Fatal("failed to put batch into queue")
}
}

bq.Close()

batchesGotten := 0
for batchesGotten != 2 {
dequeueBatch := bq.Get()
if dequeueBatch == nil {
t.Fatalf("no batch returned from get")
}
batchesGotten++
}
}
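The deleted wake-up test is likewise subsumed by channel semantics: a receive on an empty channel parks the goroutine until a send arrives, with no explicit sync.Cond signalling. A rough equivalent of that scenario against a plain channel (placeholder values, not the deleted helpers):

package main

import (
	"fmt"
	"sync"
)

func main() {
	queue := make(chan string, 10)
	ready := make(chan struct{})

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		close(ready)
		// Parks here until the send below arrives, mirroring the old
		// "put wakes a sleeping getter" scenario without any sync.Cond.
		batch := <-queue
		fmt.Println("woke up with", batch)
	}()

	<-ready
	queue <- "batch-1"
	wg.Wait()
}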

func TestWriter(t *testing.T) {
tests := []struct {
scenario string