|
@@ -19,15 +19,38 @@
|
|
|
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;
|
|
|
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
|
|
|
+import com.google.common.collect.Lists;
|
|
|
import org.apache.hadoop.yarn.exceptions.YarnException;
|
|
|
+import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation;
|
|
|
+import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.NMGpuResourceInfo;
|
|
|
+import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation;
|
|
|
+import org.junit.Assert;
|
|
|
import org.junit.Test;
|
|
|
+import java.util.List;
|
|
|
|
|
|
public class TestGpuResourcePlugin {
|
|
|
|
|
|
+ private GpuDiscoverer createMockDiscoverer() throws YarnException {
|
|
|
+ GpuDiscoverer gpuDiscoverer = mock(GpuDiscoverer.class);
|
|
|
+ when(gpuDiscoverer.isAutoDiscoveryEnabled()).thenReturn(true);
|
|
|
+
|
|
|
+ PerGpuDeviceInformation gpu =
|
|
|
+ new PerGpuDeviceInformation();
|
|
|
+ gpu.setProductName("testGpu");
|
|
|
+ List<PerGpuDeviceInformation> gpus = Lists.newArrayList();
|
|
|
+ gpus.add(gpu);
|
|
|
+
|
|
|
+ GpuDeviceInformation gpuDeviceInfo = new GpuDeviceInformation();
|
|
|
+ gpuDeviceInfo.setGpus(gpus);
|
|
|
+ when(gpuDiscoverer.getGpuDeviceInformation()).thenReturn(gpuDeviceInfo);
|
|
|
+ return gpuDiscoverer;
|
|
|
+ }
|
|
|
+
|
|
|
@Test(expected = YarnException.class)
|
|
|
public void testResourceHandlerNotInitialized() throws YarnException {
|
|
|
- GpuDiscoverer gpuDiscoverer = mock(GpuDiscoverer.class);
|
|
|
+ GpuDiscoverer gpuDiscoverer = createMockDiscoverer();
|
|
|
GpuNodeResourceUpdateHandler gpuNodeResourceUpdateHandler =
|
|
|
mock(GpuNodeResourceUpdateHandler.class);
|
|
|
|
|
@@ -39,7 +62,7 @@ public class TestGpuResourcePlugin {
|
|
|
|
|
|
@Test
|
|
|
public void testResourceHandlerIsInitialized() throws YarnException {
|
|
|
- GpuDiscoverer gpuDiscoverer = mock(GpuDiscoverer.class);
|
|
|
+ GpuDiscoverer gpuDiscoverer = createMockDiscoverer();
|
|
|
GpuNodeResourceUpdateHandler gpuNodeResourceUpdateHandler =
|
|
|
mock(GpuNodeResourceUpdateHandler.class);
|
|
|
|
|
@@ -51,4 +74,52 @@ public class TestGpuResourcePlugin {
|
|
|
//Not throwing any exception
|
|
|
target.getNMResourceInfo();
|
|
|
}
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testGetNMResourceInfoAutoDiscoveryEnabled()
|
|
|
+ throws YarnException {
|
|
|
+ GpuDiscoverer gpuDiscoverer = createMockDiscoverer();
|
|
|
+
|
|
|
+ GpuNodeResourceUpdateHandler gpuNodeResourceUpdateHandler =
|
|
|
+ mock(GpuNodeResourceUpdateHandler.class);
|
|
|
+
|
|
|
+ GpuResourcePlugin target =
|
|
|
+ new GpuResourcePlugin(gpuNodeResourceUpdateHandler, gpuDiscoverer);
|
|
|
+
|
|
|
+ target.createResourceHandler(null, null, null);
|
|
|
+
|
|
|
+ NMGpuResourceInfo resourceInfo =
|
|
|
+ (NMGpuResourceInfo) target.getNMResourceInfo();
|
|
|
+ Assert.assertNotNull("GpuDeviceInformation should not be null",
|
|
|
+ resourceInfo.getGpuDeviceInformation());
|
|
|
+
|
|
|
+ List<PerGpuDeviceInformation> gpus =
|
|
|
+ resourceInfo.getGpuDeviceInformation().getGpus();
|
|
|
+ Assert.assertNotNull("List of PerGpuDeviceInformation should not be null",
|
|
|
+ gpus);
|
|
|
+
|
|
|
+ Assert.assertEquals("List of PerGpuDeviceInformation should have a " +
|
|
|
+ "size of 1", 1, gpus.size());
|
|
|
+ Assert.assertEquals("Product name of GPU does not match",
|
|
|
+ "testGpu", gpus.get(0).getProductName());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testGetNMResourceInfoAutoDiscoveryDisabled()
|
|
|
+ throws YarnException {
|
|
|
+ GpuDiscoverer gpuDiscoverer = createMockDiscoverer();
|
|
|
+ when(gpuDiscoverer.isAutoDiscoveryEnabled()).thenReturn(false);
|
|
|
+
|
|
|
+ GpuNodeResourceUpdateHandler gpuNodeResourceUpdateHandler =
|
|
|
+ mock(GpuNodeResourceUpdateHandler.class);
|
|
|
+
|
|
|
+ GpuResourcePlugin target =
|
|
|
+ new GpuResourcePlugin(gpuNodeResourceUpdateHandler, gpuDiscoverer);
|
|
|
+
|
|
|
+ target.createResourceHandler(null, null, null);
|
|
|
+
|
|
|
+ NMGpuResourceInfo resourceInfo =
|
|
|
+ (NMGpuResourceInfo) target.getNMResourceInfo();
|
|
|
+ Assert.assertNull(resourceInfo.getGpuDeviceInformation());
|
|
|
+ }
|
|
|
}
|