diff --git a/workgroups.go b/workgroups.go index b132143..198dfe8 100644 --- a/workgroups.go +++ b/workgroups.go @@ -21,12 +21,14 @@ type Dispatcher struct { numWorkers int } -func NewDispatcher(eg *errgroup.Group, numWorkers int) *Dispatcher { +func NewDispatcher(ctx context.Context, numWorkers int) (*Dispatcher, context.Context) { + eg, ctx := errgroup.WithContext(ctx) + return &Dispatcher{ queue: make(chan Job, numWorkers), eg: eg, numWorkers: numWorkers, - } + }, ctx } func (d *Dispatcher) Start(ctx context.Context) { diff --git a/workgroups_test.go b/workgroups_test.go index 9bef8b1..ad06013 100644 --- a/workgroups_test.go +++ b/workgroups_test.go @@ -12,7 +12,6 @@ import ( "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "go.xsfx.dev/workgroups" - "golang.org/x/sync/errgroup" ) func TestDispatcher(t *testing.T) { @@ -32,8 +31,7 @@ func TestDispatcher(t *testing.T) { return nil } - eg, ctx := errgroup.WithContext(context.Background()) - d := workgroups.NewDispatcher(eg, runtime.GOMAXPROCS(0)) + d, ctx := workgroups.NewDispatcher(context.Background(), runtime.GOMAXPROCS(0)) d.Start(ctx) for i := 0; i < 10; i++ { @@ -54,8 +52,7 @@ func TestDispatcherError(t *testing.T) { return fmt.Errorf("this is an error") //nolint:goerr113 } - eg, ctx := errgroup.WithContext(context.Background()) - d := workgroups.NewDispatcher(eg, runtime.GOMAXPROCS(0)) + d, ctx := workgroups.NewDispatcher(context.Background(), runtime.GOMAXPROCS(0)) d.Start(ctx) d.Append(work) d.Close() @@ -76,8 +73,7 @@ func TestDispatcherTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second/2) defer cancel() - eg, ctx := errgroup.WithContext(ctx) - d := workgroups.NewDispatcher(eg, runtime.GOMAXPROCS(0)) + d, ctx := workgroups.NewDispatcher(ctx, runtime.GOMAXPROCS(0)) d.Start(ctx) d.Append(work) d.Close()