diff --git a/accelerators/nvidia.go b/accelerators/nvidia.go index 054d206b..496feba5 100644 --- a/accelerators/nvidia.go +++ b/accelerators/nvidia.go @@ -31,7 +31,10 @@ import ( ) type NvidiaManager struct { - sync.RWMutex + sync.Mutex + + // true if there are NVIDIA devices present on the node + devicesPresent bool // true if the NVML library (libnvidia-ml.so.1) was loaded successfully nvmlInitialized bool @@ -51,20 +54,9 @@ func (nm *NvidiaManager) Setup() { return } - nm.initializeNVML() - if nm.nvmlInitialized { - return - } - go func() { - glog.V(2).Info("Starting goroutine to initialize NVML") - // TODO: use globalHousekeepingInterval - for range time.Tick(time.Minute) { - nm.initializeNVML() - if nm.nvmlInitialized { - return - } - } - }() + nm.devicesPresent = true + + initializeNVML(nm) } // detectDevices returns true if a device with given pci id is present on the node. @@ -91,20 +83,18 @@ func detectDevices(vendorId string) bool { } // initializeNVML initializes the NVML library and sets up the nvmlDevices map. -func (nm *NvidiaManager) initializeNVML() { +// This is defined as a variable to help in testing. +var initializeNVML = func(nm *NvidiaManager) { if err := gonvml.Initialize(); err != nil { // This is under a logging level because otherwise we may cause // log spam if the drivers/nvml is not installed on the system. glog.V(4).Infof("Could not initialize NVML: %v", err) return } + nm.nvmlInitialized = true numDevices, err := gonvml.DeviceCount() if err != nil { glog.Warningf("GPU metrics would not be available. Failed to get the number of nvidia devices: %v", err) - nm.Lock() - // Even though we won't have GPU metrics, the library was initialized and should be shutdown when exiting. - nm.nvmlInitialized = true - nm.Unlock() return } glog.V(1).Infof("NVML initialized. Number of nvidia devices: %v", numDevices) @@ -122,10 +112,6 @@ func (nm *NvidiaManager) initializeNVML() { } nm.nvidiaDevices[int(minorNumber)] = device } - nm.Lock() - // Doing this at the end to avoid race in accessing nvidiaDevices in GetCollector. - nm.nvmlInitialized = true - nm.Unlock() } // Destroy shuts down NVML. @@ -139,12 +125,21 @@ func (nm *NvidiaManager) Destroy() { // present in the devices.list file in the given devicesCgroupPath. func (nm *NvidiaManager) GetCollector(devicesCgroupPath string) (AcceleratorCollector, error) { nc := &NvidiaCollector{} - nm.RLock() - if !nm.nvmlInitialized || len(nm.nvidiaDevices) == 0 { - nm.RUnlock() + + if !nm.devicesPresent { return nc, nil } - nm.RUnlock() + // Makes sure that we don't call initializeNVML() concurrently and + // that we only call initializeNVML() when it's not initialized. + nm.Lock() + if !nm.nvmlInitialized { + initializeNVML(nm) + } + if !nm.nvmlInitialized || len(nm.nvidiaDevices) == 0 { + nm.Unlock() + return nc, nil + } + nm.Unlock() nvidiaMinorNumbers, err := parseDevicesCgroup(devicesCgroupPath) if err != nil { return nc, err diff --git a/accelerators/nvidia_test.go b/accelerators/nvidia_test.go index c054433f..b7e7c4d6 100644 --- a/accelerators/nvidia_test.go +++ b/accelerators/nvidia_test.go @@ -71,13 +71,16 @@ func TestGetCollector(t *testing.T) { return []int{2, 3}, nil } parseDevicesCgroup = mockParser + originalInitializeNVML := initializeNVML + initializeNVML = func(_ *NvidiaManager) {} defer func() { parseDevicesCgroup = originalParser + initializeNVML = originalInitializeNVML }() nm := &NvidiaManager{} - // When nvmlInitialized is false, empty collector should be returned. + // When devicesPresent is false, empty collector should be returned. ac, err := nm.GetCollector("does-not-matter") assert.Nil(t, err) assert.NotNil(t, ac) @@ -85,6 +88,15 @@ func TestGetCollector(t *testing.T) { assert.True(t, ok) assert.Equal(t, 0, len(nc.Devices)) + // When nvmlInitialized is false, empty collector should be returned. + nm.devicesPresent = true + ac, err = nm.GetCollector("does-not-matter") + assert.Nil(t, err) + assert.NotNil(t, ac) + nc, ok = ac.(*NvidiaCollector) + assert.True(t, ok) + assert.Equal(t, 0, len(nc.Devices)) + // When nvidiaDevices is empty, empty collector should be returned. nm.nvmlInitialized = true ac, err = nm.GetCollector("does-not-matter")