diff --git a/channelqueue.go b/channelqueue.go index af8277c..9cfc3ae 100644 --- a/channelqueue.go +++ b/channelqueue.go @@ -1,31 +1,86 @@ package channelqueue -import "github.com/gammazero/deque" +import ( + "sync" + + "github.com/gammazero/deque" +) // ChannelQueue uses a queue to buffer data between input and output channels. type ChannelQueue[T any] struct { input, output chan T length chan int capacity int + closeOnce sync.Once } -// New creates a new ChannelQueue with the specified buffer capacity. +type Option[T any] func(*ChannelQueue[T]) + +// WithCapacity sets the limit on the number of unread items that channelqueue +// will hold. Unbuffered behavior is not supported (use a normal channel for +// that), and a value of zero or less configures the default of no limit. // -// A capacity < 0 specifies unlimited capacity. Unbuffered behavior is not -// supported; use a normal channel for that. Use caution if specifying an -// unlimited capacity since storage is still limited by system resources. -func New[T any](capacity int) *ChannelQueue[T] { - if capacity == 0 { - panic("unbuffered behavior not supported") +// Example: +// +// cq := channelqueue.New(channelqueue.WithCapacity[int](64)) +func WithCapacity[T any](n int) func(*ChannelQueue[T]) { + return func(c *ChannelQueue[T]) { + if n < 1 { + n = -1 + } + c.capacity = n } - if capacity < 0 { - capacity = -1 +} + +// WithInput uses an existing channel as the input channel, which is the +// channel used to write to the queue. This is used when buffering items that +// must be read from an existing channel. Be aware that calling Close or +// Shutdown will close this channel. +// +// Example: +// +// in := make(chan int) +// cq := channelqueue.New(channelqueue.WithInput[int](in)) +func WithInput[T any](in chan T) func(*ChannelQueue[T]) { + return func(c *ChannelQueue[T]) { + if in != nil { + c.input = in + } + } +} + +// WithOutput uses an existing channel as the output channel, which is the +// channel used to read from the queue. This is used when buffering items that +// must be written to an existing channel. Be aware that ChannelQueue will +// close this channel when no more items are available. +// +// Example: +// +// out := make(chan int) +// cq := channelqueue.New(channelqueue.WithOutput[int](out)) +func WithOutput[T any](out chan T) func(*ChannelQueue[T]) { + return func(c *ChannelQueue[T]) { + if out != nil { + c.output = out + } } +} + +// New creates a new ChannelQueue that, by default, holds an unbounded number +// of items of the specified type. +func New[T any](options ...Option[T]) *ChannelQueue[T] { cq := &ChannelQueue[T]{ - input: make(chan T), - output: make(chan T), length: make(chan int), - capacity: capacity, + capacity: -1, + } + for _, opt := range options { + opt(cq) + } + if cq.input == nil { + cq.input = make(chan T) + } + if cq.output == nil { + cq.output = make(chan T) } go cq.bufferData() return cq @@ -34,18 +89,25 @@ func New[T any](capacity int) *ChannelQueue[T] { // NewRing creates a new ChannelQueue with the specified buffer capacity, and // circular buffer behavior. When the buffer is full, writing an additional // item discards the oldest buffered item. -func NewRing[T any](capacity int) *ChannelQueue[T] { - if capacity < 1 { - return New[T](capacity) - } - +func NewRing[T any](options ...Option[T]) *ChannelQueue[T] { cq := &ChannelQueue[T]{ - input: make(chan T), - output: make(chan T), length: make(chan int), - capacity: capacity, + capacity: -1, } - if capacity == 1 { + for _, opt := range options { + opt(cq) + } + if cq.capacity < 1 { + // Unbounded ring is the same as an unbounded queue. + return New(WithInput[T](cq.input)) + } + if cq.input == nil { + cq.input = make(chan T) + } + if cq.output == nil { + cq.output = make(chan T) + } + if cq.capacity == 1 { go cq.oneBufferData() } else { go cq.ringBufferData() @@ -68,16 +130,27 @@ func (cq *ChannelQueue[T]) Len() int { return <-cq.length } -// Cap returns the capacity of the channel. +// Cap returns the capacity of the channelqueue. Returns -1 if unbounded. func (cq *ChannelQueue[T]) Cap() int { return cq.capacity } -// Close closes the input channel. Additional input will panic, output will -// continue to be readable until there is no more data, and then the output -// channel is closed. +// Close closes the input channel. This is the same as calling the builtin +// close on the input channel, except Close can be called multiple times.. +// Additional input will panic, output will continue to be readable until there +// is no more data, and then the output channel is closed. func (cq *ChannelQueue[T]) Close() { - close(cq.input) + cq.closeOnce.Do(func() { + close(cq.input) + }) +} + +// Shutdown calls Close then drains the channel to ensure that the internal +// goroutine finishes. +func (cq *ChannelQueue[T]) Shutdown() { + cq.Close() + for range cq.output { + } } // bufferData is the goroutine that transfers data from the In() chan to the diff --git a/channelqueue_test.go b/channelqueue_test.go index 7302ad2..4b3aa1a 100644 --- a/channelqueue_test.go +++ b/channelqueue_test.go @@ -7,19 +7,22 @@ import ( "time" cq "github.com/gammazero/channelqueue" + "go.uber.org/goleak" ) func TestCapLen(t *testing.T) { - ch := cq.New[int](-1) + defer goleak.VerifyNone(t) + + ch := cq.New[int]() if ch.Cap() != -1 { t.Error("expected capacity -1") } + ch.Close() - ch = cq.New[int](3) + ch = cq.New[int](cq.WithCapacity[int](3)) if ch.Cap() != 3 { t.Error("expected capacity 3") } - if ch.Len() != 0 { t.Error("expected 0 from Len()") } @@ -30,21 +33,68 @@ func TestCapLen(t *testing.T) { } in <- i } + ch.Shutdown() + + ch = cq.New(cq.WithCapacity[int](0)) + if ch.Cap() != -1 { + t.Error("expected capacity -1") + } + ch.Close() +} + +func TestExistingInput(t *testing.T) { + defer goleak.VerifyNone(t) + + in := make(chan int, 1) + ch := cq.New(cq.WithInput[int](in), cq.WithCapacity[int](64)) + in <- 42 + x := <-ch.Out() + if x != 42 { + t.Fatal("wrong value") + } + ch.Close() +} - defer func() { - if r := recover(); r == nil { - t.Error("expected panic from capacity 0") +func TestExistingOutput(t *testing.T) { + defer goleak.VerifyNone(t) + + out := make(chan int) + ch := cq.New(cq.WithOutput[int](out)) + ch.In() <- 42 + x := <-out + if x != 42 { + t.Fatal("wrong value") + } + ch.Close() +} + +func TestExistingChannels(t *testing.T) { + defer goleak.VerifyNone(t) + + in := make(chan int) + out := make(chan int) + + // Create a buffer between in and out channels. + cq.New(cq.WithInput[int](in), cq.WithOutput[int](out)) + for i := 0; i <= 100; i++ { + in <- i + } + close(in) // this will close ch when all output is read. + + expect := 0 + for x := range out { + if x != expect { + t.Fatalf("expected %d got %d", expect, x) } - }() - ch = cq.New[int](0) - if ch != nil { - t.Fatal("expected nil") + expect++ } } func TestUnlimitedSpace(t *testing.T) { + defer goleak.VerifyNone(t) + const msgCount = 1000 - ch := cq.New[int](-1) + ch := cq.New[int]() go func() { for i := 0; i < msgCount; i++ { ch.In() <- i @@ -60,8 +110,10 @@ func TestUnlimitedSpace(t *testing.T) { } func TestLimitedSpace(t *testing.T) { + defer goleak.VerifyNone(t) + const msgCount = 1000 - ch := cq.New[int](32) + ch := cq.New(cq.WithCapacity[int](32)) go func() { for i := 0; i < msgCount; i++ { ch.In() <- i @@ -77,23 +129,26 @@ func TestLimitedSpace(t *testing.T) { } func TestBufferLimit(t *testing.T) { - ch := cq.New[int](32) + defer goleak.VerifyNone(t) + + ch := cq.New(cq.WithCapacity[int](32)) + defer ch.Shutdown() + for i := 0; i < ch.Cap(); i++ { ch.In() <- i } - var timeout bool select { case ch.In() <- 999: - case <-time.After(200 * time.Millisecond): - timeout = true - } - if !timeout { t.Fatal("expected timeout on full channel") + case <-time.After(200 * time.Millisecond): } } func TestRace(t *testing.T) { - ch := cq.New[int](-1) + defer goleak.VerifyNone(t) + + ch := cq.New[int]() + defer ch.Shutdown() var err error done := make(chan struct{}) @@ -146,9 +201,11 @@ func TestRace(t *testing.T) { } func TestDouble(t *testing.T) { + defer goleak.VerifyNone(t) + const msgCount = 1000 - ch := cq.New[int](100) - recvCh := cq.New[int](100) + ch := cq.New(cq.WithCapacity[int](100)) + recvCh := cq.New(cq.WithCapacity[int](100)) go func() { for i := 0; i < msgCount; i++ { ch.In() <- i @@ -157,28 +214,41 @@ func TestDouble(t *testing.T) { }() var err error go func() { - for i := 0; i < msgCount; i++ { - val := <-ch.Out() + var i int + for val := range ch.Out() { if i != val { err = fmt.Errorf("expected %d but got %d", i, val) return } recvCh.In() <- i + i++ + } + if i != msgCount { + err = fmt.Errorf("expected %d messages from ch, got %d", msgCount, i) + return } + recvCh.Close() }() - for i := 0; i < msgCount; i++ { - val := <-recvCh.Out() + var i int + for val := range recvCh.Out() { if i != val { t.Fatal("expected", i, "but got", val) } + i++ } if err != nil { t.Fatal(err) } + if i != msgCount { + t.Fatalf("expected %d messages from recvCh, got %d", msgCount, i) + } } func TestDeadlock(t *testing.T) { - ch := cq.New[int](1) + defer goleak.VerifyNone(t) + + ch := cq.New(cq.WithCapacity[int](1)) + defer ch.Shutdown() ch.In() <- 1 <-ch.Out() @@ -196,7 +266,9 @@ func TestDeadlock(t *testing.T) { } func TestRing(t *testing.T) { - ch := cq.NewRing[rune](5) + defer goleak.VerifyNone(t) + + ch := cq.NewRing(cq.WithCapacity[rune](5)) for _, r := range "hello" { ch.In() <- r } @@ -221,19 +293,17 @@ func TestRing(t *testing.T) { t.Fatalf("expected \"fghij\" but got %q", out) } - defer func() { - if r := recover(); r == nil { - t.Error("expected panic from capacity 0") - } - }() - ch = cq.NewRing[rune](0) - if ch != nil { - t.Fatal("expected nil") + ch = cq.NewRing(cq.WithCapacity[rune](0)) + if ch.Cap() != -1 { + t.Fatal("expected -1 capacity") } + ch.Close() } func TestOneRing(t *testing.T) { - ch := cq.NewRing[rune](1) + defer goleak.VerifyNone(t) + + ch := cq.NewRing(cq.WithCapacity[rune](1)) for _, r := range "hello" { ch.In() <- r } @@ -264,19 +334,23 @@ func TestOneRing(t *testing.T) { t.Fatalf("expected \"j\" but got %q", out) } - defer func() { - if r := recover(); r == nil { - t.Error("expected panic from capacity 0") - } - }() - ch = cq.NewRing[rune](0) - if ch != nil { - t.Fatal("expected nil") + ch = cq.NewRing[rune]() + if ch.Cap() != -1 { + t.Fatal("expected -1 capacity") } + ch.Close() +} + +func TestCloseMultiple(t *testing.T) { + ch := cq.New[string]() + ch.Close() + ch.Close() + ch.Shutdown() + ch.Shutdown() } func BenchmarkSerial(b *testing.B) { - ch := cq.New[int](b.N) + ch := cq.New[int]() for i := 0; i < b.N; i++ { ch.In() <- i } @@ -286,7 +360,7 @@ func BenchmarkSerial(b *testing.B) { } func BenchmarkParallel(b *testing.B) { - ch := cq.New[int](b.N) + ch := cq.New[int]() go func() { for i := 0; i < b.N; i++ { <-ch.Out() @@ -300,7 +374,7 @@ func BenchmarkParallel(b *testing.B) { } func BenchmarkPushPull(b *testing.B) { - ch := cq.New[int](b.N) + ch := cq.New[int]() for i := 0; i < b.N; i++ { ch.In() <- i <-ch.Out() diff --git a/go.mod b/go.mod index 1c8287f..e4eb6d6 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module github.com/gammazero/channelqueue go 1.22 require github.com/gammazero/deque v1.0.0 + +require go.uber.org/goleak v1.3.0 diff --git a/go.sum b/go.sum index 4d4d3af..e30a2e1 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gammazero/deque v1.0.0 h1:LTmimT8H7bXkkCy6gZX7zNLtkbz4NdS2z8LZuor3j34= github.com/gammazero/deque v1.0.0/go.mod h1:iflpYvtGfM3U8S8j+sZEKIak3SAKYpA5/SQewgfXDKo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=