diff --git a/accelerators/nvidia.go b/accelerators/nvidia.go index 9f688ac1..b3c0d139 100644 --- a/accelerators/nvidia.go +++ b/accelerators/nvidia.go @@ -129,7 +129,7 @@ func (nm *nvidiaManager) Destroy() { // GetCollector returns a collector that can fetch nvidia gpu metrics for nvidia devices // present in the devices.list file in the given devicesCgroupPath. func (nm *nvidiaManager) GetCollector(devicesCgroupPath string) (stats.Collector, error) { - nc := &NvidiaCollector{} + nc := &nvidiaCollector{} if !nm.devicesPresent { return nc, nil @@ -154,7 +154,7 @@ func (nm *nvidiaManager) GetCollector(devicesCgroupPath string) (stats.Collector if !ok { return nc, fmt.Errorf("nvidia device minor number %d not found in cached devices", minor) } - nc.Devices = append(nc.Devices, device) + nc.devices = append(nc.devices, device) } return nc, nil } @@ -213,14 +213,18 @@ var parseDevicesCgroup = func(devicesCgroupPath string) ([]int, error) { return nvidiaMinorNumbers, nil } -type NvidiaCollector struct { +type nvidiaCollector struct { // Exposed for testing - Devices []gonvml.Device + devices []gonvml.Device +} + +func NewNvidiaCollector(devices []gonvml.Device) stats.Collector { + return &nvidiaCollector{devices: devices} } // UpdateStats updates the stats for NVIDIA GPUs (if any) attached to the container. -func (nc *NvidiaCollector) UpdateStats(stats *info.ContainerStats) error { - for _, device := range nc.Devices { +func (nc *nvidiaCollector) UpdateStats(stats *info.ContainerStats) error { + for _, device := range nc.devices { model, err := device.Name() if err != nil { return fmt.Errorf("error while getting gpu name: %v", err) diff --git a/accelerators/nvidia_test.go b/accelerators/nvidia_test.go index 92f4f1af..5250c5e4 100644 --- a/accelerators/nvidia_test.go +++ b/accelerators/nvidia_test.go @@ -84,27 +84,27 @@ func TestGetCollector(t *testing.T) { ac, err := nm.GetCollector("does-not-matter") assert.Nil(t, err) assert.NotNil(t, ac) - nc, ok := ac.(*NvidiaCollector) + nc, ok := ac.(*nvidiaCollector) assert.True(t, ok) - assert.Equal(t, 0, len(nc.Devices)) + assert.Equal(t, 0, len(nc.devices)) // When nvmlInitialized is false, empty collector should be returned. nm.devicesPresent = true ac, err = nm.GetCollector("does-not-matter") assert.Nil(t, err) assert.NotNil(t, ac) - nc, ok = ac.(*NvidiaCollector) + nc, ok = ac.(*nvidiaCollector) assert.True(t, ok) - assert.Equal(t, 0, len(nc.Devices)) + assert.Equal(t, 0, len(nc.devices)) // When nvidiaDevices is empty, empty collector should be returned. nm.nvmlInitialized = true ac, err = nm.GetCollector("does-not-matter") assert.Nil(t, err) assert.NotNil(t, ac) - nc, ok = ac.(*NvidiaCollector) + nc, ok = ac.(*nvidiaCollector) assert.True(t, ok) - assert.Equal(t, 0, len(nc.Devices)) + assert.Equal(t, 0, len(nc.devices)) // nvidiaDevices contains devices but they are different than what // is returned by parseDevicesCgroup. We should get an error. @@ -112,9 +112,9 @@ func TestGetCollector(t *testing.T) { ac, err = nm.GetCollector("does-not-matter") assert.NotNil(t, err) assert.NotNil(t, ac) - nc, ok = ac.(*NvidiaCollector) + nc, ok = ac.(*nvidiaCollector) assert.True(t, ok) - assert.Equal(t, 0, len(nc.Devices)) + assert.Equal(t, 0, len(nc.devices)) // nvidiaDevices contains devices returned by parseDevicesCgroup. // No error should be returned and collectors devices array should be @@ -124,9 +124,9 @@ func TestGetCollector(t *testing.T) { ac, err = nm.GetCollector("does-not-matter") assert.Nil(t, err) assert.NotNil(t, ac) - nc, ok = ac.(*NvidiaCollector) + nc, ok = ac.(*nvidiaCollector) assert.True(t, ok) - assert.Equal(t, 2, len(nc.Devices)) + assert.Equal(t, 2, len(nc.devices)) } func TestParseDevicesCgroup(t *testing.T) { diff --git a/manager/container_test.go b/manager/container_test.go index 85793f05..a3b5d113 100644 --- a/manager/container_test.go +++ b/manager/container_test.go @@ -217,7 +217,7 @@ func TestUpdateNvidiaStats(t *testing.T) { stats := info.ContainerStats{} // When there are no devices, we should not get an error and stats should not change. - cd.nvidiaCollector = &accelerators.NvidiaCollector{} + cd.nvidiaCollector = accelerators.NewNvidiaCollector([]gonvml.Device{}) err := cd.nvidiaCollector.UpdateStats(&stats) assert.Nil(t, err) assert.Equal(t, info.ContainerStats{}, stats) @@ -225,7 +225,7 @@ func TestUpdateNvidiaStats(t *testing.T) { // This is an impossible situation (there are devices but nvml is not initialized). // Here I am testing that the CGo gonvml library doesn't panic when passed bad // input and instead returns an error. - cd.nvidiaCollector = &accelerators.NvidiaCollector{Devices: []gonvml.Device{{}, {}}} + cd.nvidiaCollector = accelerators.NewNvidiaCollector([]gonvml.Device{{}, {}}) err = cd.nvidiaCollector.UpdateStats(&stats) assert.NotNil(t, err) assert.Equal(t, info.ContainerStats{}, stats)