Browse Source

YARN-7224. Support GPU isolation for docker container. Contributed by Wangda Tan.

Sunil G 7 years ago
parent
commit
9114d7a5a0
39 changed files with 1721 additions and 260 deletions
  1. 1 0
      hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg
  2. 29 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java
  3. 34 8
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml
  4. 1 2
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/LinuxContainerExecutor.java
  5. 45 57
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/GpuResourceAllocator.java
  6. 59 31
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/GpuResourceHandlerImpl.java
  7. 2 1
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/DefaultLinuxContainerRuntime.java
  8. 5 4
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/DelegatingLinuxContainerRuntime.java
  9. 86 5
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/DockerLinuxContainerRuntime.java
  10. 3 2
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/JavaSandboxLinuxContainerRuntime.java
  11. 3 1
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/LinuxContainerRuntime.java
  12. 5 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/DockerRunCommand.java
  13. 49 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/DockerVolumeCommand.java
  14. 59 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/DockerCommandPlugin.java
  15. 11 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/ResourcePlugin.java
  16. 78 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java
  17. 20 10
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java
  18. 41 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java
  19. 6 4
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuNodeResourceUpdateHandler.java
  20. 9 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuResourcePlugin.java
  21. 319 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV1CommandPlugin.java
  22. 33 29
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMLeveldbStateStoreService.java
  23. 2 1
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMNullStateStoreService.java
  24. 13 2
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMStateStoreService.java
  25. 130 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/utils/docker-util.c
  26. 17 1
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/utils/docker-util.h
  27. 42 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/utils/test_docker_util.cc
  28. 3 3
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestLinuxContainerExecutorWithMocks.java
  29. 5 4
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManagerRecovery.java
  30. 109 47
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/TestGpuResourceHandler.java
  31. 7 7
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/TestDelegatingLinuxContainerRuntime.java
  32. 180 24
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/TestDockerContainerRuntime.java
  33. 1 2
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/TestJavaSandboxLinuxContainerRuntime.java
  34. 2 1
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/TestDockerCommandExecutor.java
  35. 45 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/TestDockerVolumeCommand.java
  36. 26 8
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java
  37. 217 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV1CommandPlugin.java
  38. 6 2
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMMemoryStateStoreService.java
  39. 18 4
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/recovery/TestNMLeveldbStateStoreService.java

+ 1 - 0
hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg

@@ -14,3 +14,4 @@ feature.tc.enabled=0
 #  docker.allowed.ro-mounts=## comma seperated volumes that can be mounted as read-only
 #  docker.allowed.ro-mounts=## comma seperated volumes that can be mounted as read-only
 #  docker.allowed.rw-mounts=## comma seperate volumes that can be mounted as read-write, add the yarn local and log dirs to this list to run Hadoop jobs
 #  docker.allowed.rw-mounts=## comma seperate volumes that can be mounted as read-write, add the yarn local and log dirs to this list to run Hadoop jobs
 #  docker.privileged-containers.enabled=0
 #  docker.privileged-containers.enabled=0
+#  docker.allowed.volume-drivers=## comma separated list of allowed volume-drivers

+ 29 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java

@@ -1483,6 +1483,35 @@ public class YarnConfiguration extends Configuration {
   @Private
   @Private
   public static final String DEFAULT_NM_GPU_PATH_TO_EXEC = "";
   public static final String DEFAULT_NM_GPU_PATH_TO_EXEC = "";
 
 
+  /**
+   * Settings to control which implementation of docker plugin for GPU will be
+   * used.
+   *
+   * By default uses NVIDIA docker v1.
+   */
+  @Private
+  public static final String NM_GPU_DOCKER_PLUGIN_IMPL =
+      NM_GPU_RESOURCE_PREFIX + "docker-plugin";
+
+  @Private
+  public static final String NVIDIA_DOCKER_V1 = "nvidia-docker-v1";
+
+  @Private
+  public static final String DEFAULT_NM_GPU_DOCKER_PLUGIN_IMPL =
+      NVIDIA_DOCKER_V1;
+
+  /**
+   * This setting controls the end point of the nvidia-docker-v1 plugin.
+   */
+  @Private
+  public static final String NVIDIA_DOCKER_PLUGIN_V1_ENDPOINT =
+      NM_GPU_RESOURCE_PREFIX + "docker-plugin." + NVIDIA_DOCKER_V1
+          + ".endpoint";
+
+  @Private
+  public static final String DEFAULT_NVIDIA_DOCKER_PLUGIN_V1_ENDPOINT =
+      "http://localhost:3476/v1.0/docker/cli";
+
 
 
   /** NM Webapp address.**/
   /** NM Webapp address.**/
   public static final String NM_WEBAPP_ADDRESS = NM_PREFIX + "webapp.address";
   public static final String NM_WEBAPP_ADDRESS = NM_PREFIX + "webapp.address";

+ 34 - 8
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/resources/yarn-default.xml

@@ -3449,6 +3449,15 @@
     <value>/confstore</value>
     <value>/confstore</value>
   </property>
   </property>
 
 
+  <property>
+    <description>
+      Provides an option for client to load supported resource types from RM
+      instead of depending on local resource-types.xml file.
+    </description>
+    <name>yarn.client.load.resource-types.from-server</name>
+    <value>false</value>
+  </property>
+
   <property>
   <property>
     <description>
     <description>
       When yarn.nodemanager.resource.gpu.allowed-gpu-devices=auto specified,
       When yarn.nodemanager.resource.gpu.allowed-gpu-devices=auto specified,
@@ -3477,12 +3486,18 @@
       Number of GPU devices will be reported to RM to make scheduling decisions.
       Number of GPU devices will be reported to RM to make scheduling decisions.
       Set to auto (default) let YARN automatically discover GPU resource from
       Set to auto (default) let YARN automatically discover GPU resource from
       system.
       system.
+
       Manually specify GPU devices if auto detect GPU device failed or admin
       Manually specify GPU devices if auto detect GPU device failed or admin
       only want subset of GPU devices managed by YARN. GPU device is identified
       only want subset of GPU devices managed by YARN. GPU device is identified
-      by their minor device number. A common approach to get minor device number
-      of GPUs is using "nvidia-smi -q" and search "Minor Number" output. An
-      example of manual specification is "0,1,2,4" to allow YARN NodeManager
-      to manage GPU devices with minor number 0/1/2/4.
+      by their minor device number and index. A common approach to get minor
+      device number of GPUs is using "nvidia-smi -q" and search "Minor Number"
+      output.
+
+      When manually specifying minor numbers, admins need to include the
+      indices of the GPUs as well; the format is
+      index:minor_number[,index:minor_number...]. An example of manual
+      specification is "0:0,1:1,2:2,3:4", which allows the YARN NodeManager
+      to manage GPU devices with indices 0/1/2/3 and minor numbers 0/1/2/4.
     </description>
     </description>
     <name>yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices</name>
     <name>yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices</name>
     <value>auto</value>
     <value>auto</value>
@@ -3490,11 +3505,22 @@
 
 
   <property>
   <property>
     <description>
     <description>
-      Provides an option for client to load supported resource types from RM
-      instead of depending on local resource-types.xml file.
+      Specify docker command plugin for GPU. By default uses Nvidia docker V1.
     </description>
     </description>
-    <name>yarn.client.load.resource-types.from-server</name>
-    <value>false</value>
+    <name>yarn.nodemanager.resource-plugins.gpu.docker-plugin</name>
+    <value>nvidia-docker-v1</value>
+  </property>
+
+  <property>
+    <description>
+      Specify the end point of nvidia-docker-plugin.
+      Please refer to the documentation at
+      https://github.com/NVIDIA/nvidia-docker/wiki for more details.
+    </description>
+    <name>yarn.nodemanager.resource-plugins.gpu.docker-plugin.nvidia-docker-v1.endpoint</name>
+    <value>http://localhost:3476/v1.0/docker/cli</value>
   </property>
   </property>
 
 
+
 </configuration>
 </configuration>

+ 1 - 2
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/LinuxContainerExecutor.java

@@ -20,7 +20,6 @@ package org.apache.hadoop.yarn.server.nodemanager;
 
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Optional;
 import com.google.common.base.Optional;
-import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerChain;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
@@ -325,7 +324,7 @@ public class LinuxContainerExecutor extends ContainerExecutor {
       if (linuxContainerRuntime == null) {
       if (linuxContainerRuntime == null) {
         LinuxContainerRuntime runtime = new DelegatingLinuxContainerRuntime();
         LinuxContainerRuntime runtime = new DelegatingLinuxContainerRuntime();
 
 
-        runtime.initialize(conf);
+        runtime.initialize(conf, nmContext);
         this.linuxContainerRuntime = runtime;
         this.linuxContainerRuntime = runtime;
       }
       }
     } catch (ContainerExecutionException e) {
     } catch (ContainerExecutionException e) {

+ 45 - 57
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/GpuResourceAllocator.java

@@ -26,12 +26,11 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.api.records.Resource;
-import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
 import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
 import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
-import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDevice;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.io.Serializable;
 import java.io.Serializable;
@@ -54,8 +53,8 @@ import static org.apache.hadoop.yarn.api.records.ResourceInformation.GPU_URI;
 public class GpuResourceAllocator {
 public class GpuResourceAllocator {
   final static Log LOG = LogFactory.getLog(GpuResourceAllocator.class);
   final static Log LOG = LogFactory.getLog(GpuResourceAllocator.class);
 
 
-  private Set<Integer> allowedGpuDevices = new TreeSet<>();
-  private Map<Integer, ContainerId> usedDevices = new TreeMap<>();
+  private Set<GpuDevice> allowedGpuDevices = new TreeSet<>();
+  private Map<GpuDevice, ContainerId> usedDevices = new TreeMap<>();
   private Context nmContext;
   private Context nmContext;
 
 
   public GpuResourceAllocator(Context ctx) {
   public GpuResourceAllocator(Context ctx) {
@@ -63,14 +62,14 @@ public class GpuResourceAllocator {
   }
   }
 
 
   /**
   /**
-   * Contains allowed and denied devices with minor number.
+   * Contains allowed and denied devices
    * Denied devices will be useful for cgroups devices module to do blacklisting
    * Denied devices will be useful for cgroups devices module to do blacklisting
    */
    */
   static class GpuAllocation {
   static class GpuAllocation {
-    private Set<Integer> allowed = Collections.emptySet();
-    private Set<Integer> denied = Collections.emptySet();
+    private Set<GpuDevice> allowed = Collections.emptySet();
+    private Set<GpuDevice> denied = Collections.emptySet();
 
 
-    GpuAllocation(Set<Integer> allowed, Set<Integer> denied) {
+    GpuAllocation(Set<GpuDevice> allowed, Set<GpuDevice> denied) {
       if (allowed != null) {
       if (allowed != null) {
         this.allowed = ImmutableSet.copyOf(allowed);
         this.allowed = ImmutableSet.copyOf(allowed);
       }
       }
@@ -79,21 +78,21 @@ public class GpuResourceAllocator {
       }
       }
     }
     }
 
 
-    public Set<Integer> getAllowedGPUs() {
+    public Set<GpuDevice> getAllowedGPUs() {
       return allowed;
       return allowed;
     }
     }
 
 
-    public Set<Integer> getDeniedGPUs() {
+    public Set<GpuDevice> getDeniedGPUs() {
       return denied;
       return denied;
     }
     }
   }
   }
 
 
   /**
   /**
    * Add GPU to allowed list
    * Add GPU to allowed list
-   * @param minorNumber minor number of the GPU device.
+   * @param gpuDevice gpu device
    */
    */
-  public synchronized void addGpu(int minorNumber) {
-    allowedGpuDevices.add(minorNumber);
+  public synchronized void addGpu(GpuDevice gpuDevice) {
+    allowedGpuDevices.add(gpuDevice);
   }
   }
 
 
   private String getResourceHandlerExceptionMessage(int numRequestedGpuDevices,
   private String getResourceHandlerExceptionMessage(int numRequestedGpuDevices,
@@ -117,42 +116,42 @@ public class GpuResourceAllocator {
               + containerId);
               + containerId);
     }
     }
 
 
-    for (Serializable deviceId : c.getResourceMappings().getAssignedResources(
-        GPU_URI)){
-      if (!(deviceId instanceof String)) {
+    for (Serializable gpuDeviceSerializable : c.getResourceMappings()
+        .getAssignedResources(GPU_URI)) {
+      if (!(gpuDeviceSerializable instanceof GpuDevice)) {
         throw new ResourceHandlerException(
         throw new ResourceHandlerException(
             "Trying to recover device id, however it"
             "Trying to recover device id, however it"
-                + " is not String, this shouldn't happen");
+                + " is not GpuDevice, this shouldn't happen");
       }
       }
 
 
-
-      int devId;
-      try {
-        devId = Integer.parseInt((String)deviceId);
-      } catch (NumberFormatException e) {
-        throw new ResourceHandlerException("Failed to recover device id because"
-            + "it is not a valid integer, devId:" + deviceId);
-      }
+      GpuDevice gpuDevice = (GpuDevice) gpuDeviceSerializable;
 
 
       // Make sure it is in allowed GPU device.
       // Make sure it is in allowed GPU device.
-      if (!allowedGpuDevices.contains(devId)) {
-        throw new ResourceHandlerException("Try to recover device id = " + devId
-            + " however it is not in allowed device list:" + StringUtils
-            .join(",", allowedGpuDevices));
+      if (!allowedGpuDevices.contains(gpuDevice)) {
+        throw new ResourceHandlerException(
+            "Try to recover device = " + gpuDevice
+                + " however it is not in allowed device list:" + StringUtils
+                .join(",", allowedGpuDevices));
       }
       }
 
 
       // Make sure it is not occupied by anybody else
       // Make sure it is not occupied by anybody else
-      if (usedDevices.containsKey(devId)) {
-        throw new ResourceHandlerException("Try to recover device id = " + devId
-            + " however it is already assigned to container=" + usedDevices
-            .get(devId) + ", please double check what happened.");
+      if (usedDevices.containsKey(gpuDevice)) {
+        throw new ResourceHandlerException(
+            "Try to recover device id = " + gpuDevice
+                + " however it is already assigned to container=" + usedDevices
+                .get(gpuDevice) + ", please double check what happened.");
       }
       }
 
 
-      usedDevices.put(devId, containerId);
+      usedDevices.put(gpuDevice, containerId);
     }
     }
   }
   }
 
 
-  private int getRequestedGpus(Resource requestedResource) {
+  /**
+   * Get number of requested GPUs from resource.
+   * @param requestedResource requested resource
+   * @return #gpus.
+   */
+  public static int getRequestedGpus(Resource requestedResource) {
     try {
     try {
       return Long.valueOf(requestedResource.getResourceValue(
       return Long.valueOf(requestedResource.getResourceValue(
           GPU_URI)).intValue();
           GPU_URI)).intValue();
@@ -164,8 +163,8 @@ public class GpuResourceAllocator {
   /**
   /**
    * Assign GPU to requestor
    * Assign GPU to requestor
    * @param container container to allocate
    * @param container container to allocate
-   * @return List of denied Gpus with minor numbers
-   * @throws ResourceHandlerException When failed to
+   * @return allocation results.
+   * @throws ResourceHandlerException When failed to assign GPUs.
    */
    */
   public synchronized GpuAllocation assignGpus(Container container)
   public synchronized GpuAllocation assignGpus(Container container)
       throws ResourceHandlerException {
       throws ResourceHandlerException {
@@ -180,12 +179,12 @@ public class GpuResourceAllocator {
                 containerId));
                 containerId));
       }
       }
 
 
-      Set<Integer> assignedGpus = new HashSet<>();
+      Set<GpuDevice> assignedGpus = new TreeSet<>();
 
 
-      for (int deviceNum : allowedGpuDevices) {
-        if (!usedDevices.containsKey(deviceNum)) {
-          usedDevices.put(deviceNum, containerId);
-          assignedGpus.add(deviceNum);
+      for (GpuDevice gpu : allowedGpuDevices) {
+        if (!usedDevices.containsKey(gpu)) {
+          usedDevices.put(gpu, containerId);
+          assignedGpus.add(gpu);
           if (assignedGpus.size() == numRequestedGpuDevices) {
           if (assignedGpus.size() == numRequestedGpuDevices) {
             break;
             break;
           }
           }
@@ -194,21 +193,10 @@ public class GpuResourceAllocator {
 
 
       // Record in state store if we allocated anything
       // Record in state store if we allocated anything
       if (!assignedGpus.isEmpty()) {
       if (!assignedGpus.isEmpty()) {
-        List<Serializable> allocatedDevices = new ArrayList<>();
-        for (int gpu : assignedGpus) {
-          allocatedDevices.add(String.valueOf(gpu));
-        }
         try {
         try {
-          // Update Container#getResourceMapping.
-          ResourceMappings.AssignedResources assignedResources =
-              new ResourceMappings.AssignedResources();
-          assignedResources.updateAssignedResources(allocatedDevices);
-          container.getResourceMappings().addAssignedResources(GPU_URI,
-              assignedResources);
-
           // Update state store.
           // Update state store.
-          nmContext.getNMStateStore().storeAssignedResources(containerId,
-              GPU_URI, allocatedDevices);
+          nmContext.getNMStateStore().storeAssignedResources(container, GPU_URI,
+              new ArrayList<>(assignedGpus));
         } catch (IOException e) {
         } catch (IOException e) {
           cleanupAssignGpus(containerId);
           cleanupAssignGpus(containerId);
           throw new ResourceHandlerException(e);
           throw new ResourceHandlerException(e);
@@ -226,7 +214,7 @@ public class GpuResourceAllocator {
    * @param containerId containerId
    * @param containerId containerId
    */
    */
   public synchronized void cleanupAssignGpus(ContainerId containerId) {
   public synchronized void cleanupAssignGpus(ContainerId containerId) {
-    Iterator<Map.Entry<Integer, ContainerId>> iter =
+    Iterator<Map.Entry<GpuDevice, ContainerId>> iter =
         usedDevices.entrySet().iterator();
         usedDevices.entrySet().iterator();
     while (iter.hasNext()) {
     while (iter.hasNext()) {
       if (iter.next().getValue().equals(containerId)) {
       if (iter.next().getValue().equals(containerId)) {
@@ -236,7 +224,7 @@ public class GpuResourceAllocator {
   }
   }
 
 
   @VisibleForTesting
   @VisibleForTesting
-  public synchronized Map<Integer, ContainerId> getDeviceAllocationMapping() {
+  public synchronized Map<GpuDevice, ContainerId> getDeviceAllocationMapping() {
      return new HashMap<>(usedDevices);
      return new HashMap<>(usedDevices);
   }
   }
 }
 }

+ 59 - 31
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/GpuResourceHandlerImpl.java

@@ -24,8 +24,6 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
-import org.apache.hadoop.yarn.api.records.ResourceInformation;
-import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
@@ -35,6 +33,8 @@ import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileg
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.DockerLinuxContainerRuntime;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDevice;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDiscoverer;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDiscoverer;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
@@ -64,17 +64,23 @@ public class GpuResourceHandlerImpl implements ResourceHandler {
   @Override
   @Override
   public List<PrivilegedOperation> bootstrap(Configuration configuration)
   public List<PrivilegedOperation> bootstrap(Configuration configuration)
       throws ResourceHandlerException {
       throws ResourceHandlerException {
-    List<Integer> minorNumbersOfUsableGpus;
+    List<GpuDevice> usableGpus;
     try {
     try {
-      minorNumbersOfUsableGpus = GpuDiscoverer.getInstance()
-          .getMinorNumbersOfGpusUsableByYarn();
+      usableGpus = GpuDiscoverer.getInstance()
+          .getGpusUsableByYarn();
+      if (usableGpus == null || usableGpus.isEmpty()) {
+        String message = "GPU is enabled on the NodeManager, but couldn't find "
+            + "any usable GPU devices, please double check configuration.";
+        LOG.error(message);
+        throw new ResourceHandlerException(message);
+      }
     } catch (YarnException e) {
     } catch (YarnException e) {
       LOG.error("Exception when trying to get usable GPU device", e);
       LOG.error("Exception when trying to get usable GPU device", e);
       throw new ResourceHandlerException(e);
       throw new ResourceHandlerException(e);
     }
     }
 
 
-    for (int minorNumber : minorNumbersOfUsableGpus) {
-      gpuAllocator.addGpu(minorNumber);
+    for (GpuDevice gpu : usableGpus) {
+      gpuAllocator.addGpu(gpu);
     }
     }
 
 
     // And initialize cgroups
     // And initialize cgroups
@@ -96,33 +102,55 @@ public class GpuResourceHandlerImpl implements ResourceHandler {
     // Create device cgroups for the container
     // Create device cgroups for the container
     cGroupsHandler.createCGroup(CGroupsHandler.CGroupController.DEVICES,
     cGroupsHandler.createCGroup(CGroupsHandler.CGroupController.DEVICES,
         containerIdStr);
         containerIdStr);
-    try {
-      // Execute c-e to setup GPU isolation before launch the container
-      PrivilegedOperation privilegedOperation = new PrivilegedOperation(
-          PrivilegedOperation.OperationType.GPU, Arrays
-          .asList(CONTAINER_ID_CLI_OPTION, containerIdStr));
-      if (!allocation.getDeniedGPUs().isEmpty()) {
-        privilegedOperation.appendArgs(Arrays.asList(EXCLUDED_GPUS_CLI_OPTION,
-            StringUtils.join(",", allocation.getDeniedGPUs())));
+    if (!DockerLinuxContainerRuntime.isDockerContainerRequested(
+        container.getLaunchContext().getEnvironment())) {
+      // Write to devices cgroup only for non-docker container. The reason is
+      // docker engine runtime runc do the devices cgroups initialize in the
+      // pre-hook, see:
+      //   https://github.com/opencontainers/runc/blob/master/libcontainer/configs/device_defaults.go
+      //
+      // YARN by default runs docker container inside cgroup, if we setup cgroups
+      // devices.deny for the parent cgroup for launched container, we can see
+      // errors like: failed to write c *:* m to devices.allow:
+      // write path-to-parent-cgroup/<container-id>/devices.allow:
+      // operation not permitted.
+      //
+      // To avoid this happen, if docker is requested when container being
+      // launched, we will not setup devices.deny for the container. Instead YARN
+      // will pass --device parameter to docker engine. See NvidiaDockerV1CommandPlugin
+      try {
+        // Execute c-e to setup GPU isolation before launch the container
+        PrivilegedOperation privilegedOperation = new PrivilegedOperation(
+            PrivilegedOperation.OperationType.GPU,
+            Arrays.asList(CONTAINER_ID_CLI_OPTION, containerIdStr));
+        if (!allocation.getDeniedGPUs().isEmpty()) {
+          List<Integer> minorNumbers = new ArrayList<>();
+          for (GpuDevice deniedGpu : allocation.getDeniedGPUs()) {
+            minorNumbers.add(deniedGpu.getMinorNumber());
+          }
+          privilegedOperation.appendArgs(Arrays.asList(EXCLUDED_GPUS_CLI_OPTION,
+              StringUtils.join(",", minorNumbers)));
+        }
+
+        privilegedOperationExecutor.executePrivilegedOperation(
+            privilegedOperation, true);
+      } catch (PrivilegedOperationException e) {
+        cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
+            containerIdStr);
+        LOG.warn("Could not update cgroup for container", e);
+        throw new ResourceHandlerException(e);
       }
       }
 
 
-      privilegedOperationExecutor.executePrivilegedOperation(
-          privilegedOperation, true);
-    } catch (PrivilegedOperationException e) {
-      cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
-          containerIdStr);
-      LOG.warn("Could not update cgroup for container", e);
-      throw new ResourceHandlerException(e);
-    }
+      List<PrivilegedOperation> ret = new ArrayList<>();
+      ret.add(new PrivilegedOperation(
+          PrivilegedOperation.OperationType.ADD_PID_TO_CGROUP,
+          PrivilegedOperation.CGROUP_ARG_PREFIX + cGroupsHandler
+              .getPathForCGroupTasks(CGroupsHandler.CGroupController.DEVICES,
+                  containerIdStr)));
 
 
-    List<PrivilegedOperation> ret = new ArrayList<>();
-    ret.add(new PrivilegedOperation(
-        PrivilegedOperation.OperationType.ADD_PID_TO_CGROUP,
-        PrivilegedOperation.CGROUP_ARG_PREFIX
-            + cGroupsHandler.getPathForCGroupTasks(
-            CGroupsHandler.CGroupController.DEVICES, containerIdStr)));
-
-    return ret;
+      return ret;
+    }
+    return null;
   }
   }
 
 
   @VisibleForTesting
   @VisibleForTesting

+ 2 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/DefaultLinuxContainerRuntime.java

@@ -25,6 +25,7 @@ import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.yarn.server.nodemanager.ContainerExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.ContainerExecutor;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
@@ -67,7 +68,7 @@ public class DefaultLinuxContainerRuntime implements LinuxContainerRuntime {
   }
   }
 
 
   @Override
   @Override
-  public void initialize(Configuration conf)
+  public void initialize(Configuration conf, Context nmContext)
       throws ContainerExecutionException {
       throws ContainerExecutionException {
     this.conf = conf;
     this.conf = conf;
   }
   }

+ 5 - 4
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/DelegatingLinuxContainerRuntime.java

@@ -25,6 +25,7 @@ import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
@@ -57,7 +58,7 @@ public class DelegatingLinuxContainerRuntime implements LinuxContainerRuntime {
       EnumSet.noneOf(LinuxContainerRuntimeConstants.RuntimeType.class);
       EnumSet.noneOf(LinuxContainerRuntimeConstants.RuntimeType.class);
 
 
   @Override
   @Override
-  public void initialize(Configuration conf)
+  public void initialize(Configuration conf, Context nmContext)
       throws ContainerExecutionException {
       throws ContainerExecutionException {
     String[] configuredRuntimes = conf.getTrimmedStrings(
     String[] configuredRuntimes = conf.getTrimmedStrings(
         YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
@@ -77,19 +78,19 @@ public class DelegatingLinuxContainerRuntime implements LinuxContainerRuntime {
         LinuxContainerRuntimeConstants.RuntimeType.JAVASANDBOX)) {
         LinuxContainerRuntimeConstants.RuntimeType.JAVASANDBOX)) {
       javaSandboxLinuxContainerRuntime = new JavaSandboxLinuxContainerRuntime(
       javaSandboxLinuxContainerRuntime = new JavaSandboxLinuxContainerRuntime(
           PrivilegedOperationExecutor.getInstance(conf));
           PrivilegedOperationExecutor.getInstance(conf));
-      javaSandboxLinuxContainerRuntime.initialize(conf);
+      javaSandboxLinuxContainerRuntime.initialize(conf, nmContext);
     }
     }
     if (isRuntimeAllowed(
     if (isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.DOCKER)) {
         LinuxContainerRuntimeConstants.RuntimeType.DOCKER)) {
       dockerLinuxContainerRuntime = new DockerLinuxContainerRuntime(
       dockerLinuxContainerRuntime = new DockerLinuxContainerRuntime(
           PrivilegedOperationExecutor.getInstance(conf));
           PrivilegedOperationExecutor.getInstance(conf));
-      dockerLinuxContainerRuntime.initialize(conf);
+      dockerLinuxContainerRuntime.initialize(conf, nmContext);
     }
     }
     if (isRuntimeAllowed(
     if (isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.DEFAULT)) {
         LinuxContainerRuntimeConstants.RuntimeType.DEFAULT)) {
       defaultLinuxContainerRuntime = new DefaultLinuxContainerRuntime(
       defaultLinuxContainerRuntime = new DefaultLinuxContainerRuntime(
           PrivilegedOperationExecutor.getInstance(conf));
           PrivilegedOperationExecutor.getInstance(conf));
-      defaultLinuxContainerRuntime.initialize(conf);
+      defaultLinuxContainerRuntime.initialize(conf, nmContext);
     }
     }
   }
   }
 
 

+ 86 - 5
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/DockerLinuxContainerRuntime.java

@@ -21,6 +21,10 @@
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime;
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime;
 
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePlugin;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceAudience;
@@ -172,6 +176,7 @@ public class DockerLinuxContainerRuntime implements LinuxContainerRuntime {
       "YARN_CONTAINER_RUNTIME_DOCKER_LOCAL_RESOURCE_MOUNTS";
       "YARN_CONTAINER_RUNTIME_DOCKER_LOCAL_RESOURCE_MOUNTS";
 
 
   private Configuration conf;
   private Configuration conf;
+  private Context nmContext;
   private DockerClient dockerClient;
   private DockerClient dockerClient;
   private PrivilegedOperationExecutor privilegedOperationExecutor;
   private PrivilegedOperationExecutor privilegedOperationExecutor;
   private Set<String> allowedNetworks = new HashSet<>();
   private Set<String> allowedNetworks = new HashSet<>();
@@ -220,14 +225,14 @@ public class DockerLinuxContainerRuntime implements LinuxContainerRuntime {
    * Create an instance using the given {@link PrivilegedOperationExecutor}
    * Create an instance using the given {@link PrivilegedOperationExecutor}
    * instance for performing operations and the given {@link CGroupsHandler}
    * instance for performing operations and the given {@link CGroupsHandler}
    * instance. This constructor is intended for use in testing.
    * instance. This constructor is intended for use in testing.
-   *
-   * @param privilegedOperationExecutor the {@link PrivilegedOperationExecutor}
+   *  @param privilegedOperationExecutor the {@link PrivilegedOperationExecutor}
    * instance
    * instance
    * @param cGroupsHandler the {@link CGroupsHandler} instance
    * @param cGroupsHandler the {@link CGroupsHandler} instance
    */
    */
   @VisibleForTesting
   @VisibleForTesting
-  public DockerLinuxContainerRuntime(PrivilegedOperationExecutor
-      privilegedOperationExecutor, CGroupsHandler cGroupsHandler) {
+  public DockerLinuxContainerRuntime(
+      PrivilegedOperationExecutor privilegedOperationExecutor,
+      CGroupsHandler cGroupsHandler) {
     this.privilegedOperationExecutor = privilegedOperationExecutor;
     this.privilegedOperationExecutor = privilegedOperationExecutor;
 
 
     if (cGroupsHandler == null) {
     if (cGroupsHandler == null) {
@@ -239,8 +244,9 @@ public class DockerLinuxContainerRuntime implements LinuxContainerRuntime {
   }
   }
 
 
   @Override
   @Override
-  public void initialize(Configuration conf)
+  public void initialize(Configuration conf, Context nmContext)
       throws ContainerExecutionException {
       throws ContainerExecutionException {
+    this.nmContext = nmContext;
     this.conf = conf;
     this.conf = conf;
     dockerClient = new DockerClient(conf);
     dockerClient = new DockerClient(conf);
     allowedNetworks.clear();
     allowedNetworks.clear();
@@ -288,9 +294,54 @@ public class DockerLinuxContainerRuntime implements LinuxContainerRuntime {
     return false;
     return false;
   }
   }
 
 
  /**
   * Serializes the given docker volume command to a temp file and executes it
   * through the privileged container-executor
   * ({@code PrivilegedOperation.OperationType.RUN_DOCKER_CMD}).
   *
   * @param dockerVolumeCommand the docker volume command to run
   * @param container the NM container the command is run on behalf of; its
   *        container id names the temp command file and tags the log output
   * @throws ContainerExecutionException if the command file cannot be written
   *         or the privileged execution fails (the latter is wrapped,
   *         preserving the cause)
   */
  private void runDockerVolumeCommand(DockerVolumeCommand dockerVolumeCommand,
      Container container) throws ContainerExecutionException {
    try {
      // Write the docker command to a file the container-executor will read.
      String commandFile = dockerClient.writeCommandToTempFile(
          dockerVolumeCommand, container.getContainerId().toString());
      PrivilegedOperation privOp = new PrivilegedOperation(
          PrivilegedOperation.OperationType.RUN_DOCKER_CMD);
      privOp.appendArgs(commandFile);
      // Execute and capture stdout so the volume operation is traceable.
      String output = privilegedOperationExecutor
          .executePrivilegedOperation(null, privOp, null,
              null, true, false);
      LOG.info("ContainerId=" + container.getContainerId()
          + ", docker volume output for " + dockerVolumeCommand + ": "
          + output);
    } catch (ContainerExecutionException e) {
      // Writing the command file failed; rethrow unchanged.
      LOG.error("Error when writing command to temp file, command="
              + dockerVolumeCommand,
          e);
      throw e;
    } catch (PrivilegedOperationException e) {
      // Execution failed; wrap while keeping the original cause.
      LOG.error("Error when executing command, command="
          + dockerVolumeCommand, e);
      throw new ContainerExecutionException(e);
    }

  }
+
   @Override
   @Override
   public void prepareContainer(ContainerRuntimeContext ctx)
   public void prepareContainer(ContainerRuntimeContext ctx)
       throws ContainerExecutionException {
       throws ContainerExecutionException {
+    Container container = ctx.getContainer();
+
+    // Create volumes when needed.
+    if (nmContext != null
+        && nmContext.getResourcePluginManager().getNameToPlugins() != null) {
+      for (ResourcePlugin plugin : nmContext.getResourcePluginManager()
+          .getNameToPlugins().values()) {
+        DockerCommandPlugin dockerCommandPlugin =
+            plugin.getDockerCommandPluginInstance();
+        if (dockerCommandPlugin != null) {
+          DockerVolumeCommand dockerVolumeCommand =
+              dockerCommandPlugin.getCreateDockerVolumeCommand(ctx.getContainer());
+          if (dockerVolumeCommand != null) {
+            runDockerVolumeCommand(dockerVolumeCommand, container);
+          }
+        }
+      }
+    }
   }
   }
 
 
   private void validateContainerNetworkType(String network)
   private void validateContainerNetworkType(String network)
@@ -623,6 +674,19 @@ public class DockerLinuxContainerRuntime implements LinuxContainerRuntime {
       runCommand.groupAdd(groups);
       runCommand.groupAdd(groups);
     }
     }
 
 
+    // use plugins to update docker run command.
+    if (nmContext != null
+        && nmContext.getResourcePluginManager().getNameToPlugins() != null) {
+      for (ResourcePlugin plugin : nmContext.getResourcePluginManager()
+          .getNameToPlugins().values()) {
+        DockerCommandPlugin dockerCommandPlugin =
+            plugin.getDockerCommandPluginInstance();
+        if (dockerCommandPlugin != null) {
+          dockerCommandPlugin.updateDockerRunCommand(runCommand, container);
+        }
+      }
+    }
+
     String commandFile = dockerClient.writeCommandToTempFile(runCommand,
     String commandFile = dockerClient.writeCommandToTempFile(runCommand,
         containerIdStr);
         containerIdStr);
     PrivilegedOperation launchOp = buildLaunchOp(ctx,
     PrivilegedOperation launchOp = buildLaunchOp(ctx,
@@ -683,6 +747,23 @@ public class DockerLinuxContainerRuntime implements LinuxContainerRuntime {
   @Override
   @Override
   public void reapContainer(ContainerRuntimeContext ctx)
   public void reapContainer(ContainerRuntimeContext ctx)
       throws ContainerExecutionException {
       throws ContainerExecutionException {
+    // Cleanup volumes when needed.
+    if (nmContext != null
+        && nmContext.getResourcePluginManager().getNameToPlugins() != null) {
+      for (ResourcePlugin plugin : nmContext.getResourcePluginManager()
+          .getNameToPlugins().values()) {
+        DockerCommandPlugin dockerCommandPlugin =
+            plugin.getDockerCommandPluginInstance();
+        if (dockerCommandPlugin != null) {
+          DockerVolumeCommand dockerVolumeCommand =
+              dockerCommandPlugin.getCleanupDockerVolumesCommand(
+                  ctx.getContainer());
+          if (dockerVolumeCommand != null) {
+            runDockerVolumeCommand(dockerVolumeCommand, ctx.getContainer());
+          }
+        }
+      }
+    }
   }
   }
 
 
 
 

+ 3 - 2
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/JavaSandboxLinuxContainerRuntime.java

@@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.IOUtils;
 import org.apache.hadoop.io.IOUtils;
 import org.apache.hadoop.security.Groups;
 import org.apache.hadoop.security.Groups;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeContext;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeContext;
@@ -143,7 +144,7 @@ public class JavaSandboxLinuxContainerRuntime
   }
   }
 
 
   @Override
   @Override
-  public void initialize(Configuration conf)
+  public void initialize(Configuration conf, Context nmContext)
       throws ContainerExecutionException {
       throws ContainerExecutionException {
     this.configuration = conf;
     this.configuration = conf;
     this.sandboxMode =
     this.sandboxMode =
@@ -151,7 +152,7 @@ public class JavaSandboxLinuxContainerRuntime
             this.configuration.get(YARN_CONTAINER_SANDBOX,
             this.configuration.get(YARN_CONTAINER_SANDBOX,
                 YarnConfiguration.DEFAULT_YARN_CONTAINER_SANDBOX));
                 YarnConfiguration.DEFAULT_YARN_CONTAINER_SANDBOX));
 
 
-    super.initialize(conf);
+    super.initialize(conf, nmContext);
   }
   }
 
 
   /**
   /**

+ 3 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/LinuxContainerRuntime.java

@@ -23,6 +23,7 @@ package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntime;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntime;
 
 
@@ -38,9 +39,10 @@ public interface LinuxContainerRuntime extends ContainerRuntime {
    * Initialize the runtime.
    * Initialize the runtime.
    *
    *
    * @param conf the {@link Configuration} to use
    * @param conf the {@link Configuration} to use
+   * @param nmContext NMContext
    * @throws ContainerExecutionException if an error occurs while initializing
    * @throws ContainerExecutionException if an error occurs while initializing
    * the runtime
    * the runtime
    */
    */
-  void initialize(Configuration conf) throws ContainerExecutionException;
+  void initialize(Configuration conf, Context nmContext) throws ContainerExecutionException;
 }
 }
 
 

+ 5 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/DockerRunCommand.java

@@ -76,6 +76,11 @@ public class DockerRunCommand extends DockerCommand {
     return this;
     return this;
   }
   }
 
 
  /**
   * Records the volume-driver argument for the docker run command
   * (presumably rendered as {@code --volume-driver} — confirm against
   * DockerClient's command serialization).
   *
   * @param volumeDriver name of the docker volume driver to use
   * @return this {@link DockerRunCommand} instance, for chaining
   */
  public DockerRunCommand setVolumeDriver(String volumeDriver) {
    super.addCommandArguments("volume-driver", volumeDriver);
    return this;
  }
+
   public DockerRunCommand setCGroupParent(String parentPath) {
   public DockerRunCommand setCGroupParent(String parentPath) {
     super.addCommandArguments("cgroup-parent", parentPath);
     super.addCommandArguments("cgroup-parent", parentPath);
     return this;
     return this;

+ 49 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/DockerVolumeCommand.java

@@ -0,0 +1,49 @@
+/*
+ * *
+ *  Licensed to the Apache Software Foundation (ASF) under one
+ *  or more contributor license agreements.  See the NOTICE file
+ *  distributed with this work for additional information
+ *  regarding copyright ownership.  The ASF licenses this file
+ *  to you under the Apache License, Version 2.0 (the
+ *  "License"); you may not use this file except in compliance
+ *  with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ * /
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker;
+
+import java.util.regex.Pattern;
+
+/**
+ * Docker Volume Command, run "docker volume --help" for more details.
+ */
+public class DockerVolumeCommand extends DockerCommand {
+  public static final String VOLUME_COMMAND = "volume";
+  public static final String VOLUME_CREATE_COMMAND = "create";
+  // Regex pattern for volume name
+  public static final Pattern VOLUME_NAME_PATTERN = Pattern.compile(
+      "[a-zA-Z0-9][a-zA-Z0-9_.-]*");
+
+  public DockerVolumeCommand(String subCommand) {
+    super(VOLUME_COMMAND);
+    super.addCommandArguments("sub-command", subCommand);
+  }
+
+  public DockerVolumeCommand setVolumeName(String volumeName) {
+    super.addCommandArguments("volume", volumeName);
+    return this;
+  }
+
+  public DockerVolumeCommand setDriverName(String driverName) {
+    super.addCommandArguments("driver", driverName);
+    return this;
+  }
+}

+ 59 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/DockerCommandPlugin.java

@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin;
+
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
+
/**
 * Interface that lets resource plugins (e.g. GPU) update docker commands
 * without adding plugin-specific logic to the Docker runtime.
 */
public interface DockerCommandPlugin {
  /**
   * Updates the docker run command before the container is launched.
   *
   * @param dockerRunCommand docker run command to update in place
   * @param container NM container to be launched
   * @throws ContainerExecutionException if any issue occurs
   */
  void updateDockerRunCommand(DockerRunCommand dockerRunCommand,
      Container container) throws ContainerExecutionException;

  /**
   * Returns a command to create docker volumes needed by the container,
   * when any are needed.
   *
   * @param container container to be launched
   * @return {@link DockerVolumeCommand} to create the volume, or null when
   *         no volume needs to be created
   * @throws ContainerExecutionException when any issue happens
   */
  DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
      throws ContainerExecutionException;

  /**
   * Returns a command to clean up volumes created for one docker container,
   * when any cleanup is needed.
   *
   * @param container container being cleaned up
   * @return {@link DockerVolumeCommand} to remove the volume, or null when
   *         nothing needs cleanup
   * @throws ContainerExecutionException when any issue happens
   */
  DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
      throws ContainerExecutionException;

  // Add support for other docker commands when required.
}

+ 11 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/ResourcePlugin.java

@@ -24,6 +24,7 @@ import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileg
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerChain;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerChain;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.DockerLinuxContainerRuntime;
 
 
 /**
 /**
  * {@link ResourcePlugin} is an interface for node manager to easier support
  * {@link ResourcePlugin} is an interface for node manager to easier support
@@ -80,4 +81,14 @@ public interface ResourcePlugin {
    * @throws YarnException if any issue occurs
    * @throws YarnException if any issue occurs
    */
    */
   void cleanup() throws YarnException;
   void cleanup() throws YarnException;
+
+  /**
+   * Plugins may provide a {@link DockerCommandPlugin}. It is invoked by
+   * {@link DockerLinuxContainerRuntime} when executing docker commands such
+   * as run/stop/pull, etc.
+   *
+   * @return a DockerCommandPlugin instance, or null if the plugin does not
+   *         need to update docker commands.
+   */
+  DockerCommandPlugin getDockerCommandPluginInstance();
 }
 }

+ 78 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java

@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;
+
+import java.io.Serializable;
+
+/**
+ * This class is used to represent GPU device while allocation.
+ */
+public class GpuDevice implements Serializable, Comparable {
+  private int index;
+  private int minorNumber;
+  private static final long serialVersionUID = -6812314470754667710L;
+
+  public GpuDevice(int index, int minorNumber) {
+    this.index = index;
+    this.minorNumber = minorNumber;
+  }
+
+  public int getIndex() {
+    return index;
+  }
+
+  public int getMinorNumber() {
+    return minorNumber;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj == null || !(obj instanceof GpuDevice)) {
+      return false;
+    }
+    GpuDevice other = (GpuDevice) obj;
+    return index == other.index && minorNumber == other.minorNumber;
+  }
+
+  @Override
+  public int compareTo(Object obj) {
+    if (obj == null || (!(obj instanceof  GpuDevice))) {
+      return -1;
+    }
+
+    GpuDevice other = (GpuDevice) obj;
+
+    int result = Integer.compare(index, other.index);
+    if (0 != result) {
+      return result;
+    }
+    return Integer.compare(minorNumber, other.minorNumber);
+  }
+
+  @Override
+  public int hashCode() {
+    final int prime = 47;
+    return prime * index + minorNumber;
+  }
+
+  @Override
+  public String toString() {
+    return "(index=" + index + ",minor_number=" + minorNumber + ")";
+  }
+}

+ 20 - 10
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java

@@ -136,12 +136,12 @@ public class GpuDiscoverer {
   }
   }
 
 
   /**
   /**
-   * Get list of minor device numbers of Gpu devices usable by YARN.
+   * Get list of GPU devices usable by YARN.
    *
    *
-   * @return List of minor device numbers of Gpu devices.
+   * @return List of GPU devices
    * @throws YarnException when any issue happens
    * @throws YarnException when any issue happens
    */
    */
-  public synchronized List<Integer> getMinorNumbersOfGpusUsableByYarn()
+  public synchronized List<GpuDevice> getGpusUsableByYarn()
       throws YarnException {
       throws YarnException {
     validateConfOrThrowException();
     validateConfOrThrowException();
 
 
@@ -149,7 +149,7 @@ public class GpuDiscoverer {
         YarnConfiguration.NM_GPU_ALLOWED_DEVICES,
         YarnConfiguration.NM_GPU_ALLOWED_DEVICES,
         YarnConfiguration.AUTOMATICALLY_DISCOVER_GPU_DEVICES);
         YarnConfiguration.AUTOMATICALLY_DISCOVER_GPU_DEVICES);
 
 
-    List<Integer> minorNumbers = new ArrayList<>();
+    List<GpuDevice> gpuDevices = new ArrayList<>();
 
 
     if (allowedDevicesStr.equals(
     if (allowedDevicesStr.equals(
         YarnConfiguration.AUTOMATICALLY_DISCOVER_GPU_DEVICES)) {
         YarnConfiguration.AUTOMATICALLY_DISCOVER_GPU_DEVICES)) {
@@ -167,21 +167,31 @@ public class GpuDiscoverer {
       }
       }
 
 
       if (lastDiscoveredGpuInformation.getGpus() != null) {
       if (lastDiscoveredGpuInformation.getGpus() != null) {
-        for (PerGpuDeviceInformation gpu : lastDiscoveredGpuInformation
-            .getGpus()) {
-          minorNumbers.add(gpu.getMinorNumber());
+        for (int i = 0; i < lastDiscoveredGpuInformation.getGpus().size();
+             i++) {
+          List<PerGpuDeviceInformation> gpuInfos =
+              lastDiscoveredGpuInformation.getGpus();
+          gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber()));
         }
         }
       }
       }
     } else{
     } else{
       for (String s : allowedDevicesStr.split(",")) {
       for (String s : allowedDevicesStr.split(",")) {
         if (s.trim().length() > 0) {
         if (s.trim().length() > 0) {
-          minorNumbers.add(Integer.valueOf(s.trim()));
+          String[] kv = s.trim().split(":");
+          if (kv.length != 2) {
+            throw new YarnException(
+                "Illegal format, it should be index:minor_number format, now it="
+                    + s);
+          }
+
+          gpuDevices.add(
+              new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1])));
         }
         }
       }
       }
-      LOG.info("Allowed GPU devices with minor numbers:" + allowedDevicesStr);
+      LOG.info("Allowed GPU devices:" + gpuDevices);
     }
     }
 
 
-    return minorNumbers;
+    return gpuDevices;
   }
   }
 
 
   public synchronized void initialize(Configuration conf) throws YarnException {
   public synchronized void initialize(Configuration conf) throws YarnException {

+ 41 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java

@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
+
+/**
+ * Factory to create GpuDocker Command Plugin instance
+ */
+public class GpuDockerCommandPluginFactory {
+  public static DockerCommandPlugin createGpuDockerCommandPlugin(
+      Configuration conf) throws YarnException {
+    String impl = conf.get(YarnConfiguration.NM_GPU_DOCKER_PLUGIN_IMPL,
+        YarnConfiguration.DEFAULT_NM_GPU_DOCKER_PLUGIN_IMPL);
+    if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V1)) {
+      return new NvidiaDockerV1CommandPlugin(conf);
+    }
+
+    throw new YarnException(
+        "Unkown implementation name for Gpu docker plugin, impl=" + impl);
+  }
+}

+ 6 - 4
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuNodeResourceUpdateHandler.java

@@ -40,12 +40,14 @@ public class GpuNodeResourceUpdateHandler extends NodeResourceUpdaterPlugin {
   public void updateConfiguredResource(Resource res) throws YarnException {
   public void updateConfiguredResource(Resource res) throws YarnException {
     LOG.info("Initializing configured GPU resources for the NodeManager.");
     LOG.info("Initializing configured GPU resources for the NodeManager.");
 
 
-    List<Integer> usableGpus =
-        GpuDiscoverer.getInstance().getMinorNumbersOfGpusUsableByYarn();
+    List<GpuDevice> usableGpus =
+        GpuDiscoverer.getInstance().getGpusUsableByYarn();
     if (null == usableGpus || usableGpus.isEmpty()) {
     if (null == usableGpus || usableGpus.isEmpty()) {
-      LOG.info("Didn't find any usable GPUs on the NodeManager.");
+      String message = "GPU is enabled, but couldn't find any usable GPUs on the "
+          + "NodeManager.";
+      LOG.error(message);
       // No gpu can be used by YARN.
       // No gpu can be used by YARN.
-      return;
+      throw new YarnException(message);
     }
     }
 
 
     long nUsableGpus = usableGpus.size();
     long nUsableGpus = usableGpus.size();

+ 9 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuResourcePlugin.java

@@ -24,17 +24,22 @@ import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileg
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceHandlerImpl;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceHandlerImpl;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.NodeResourceUpdaterPlugin;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.NodeResourceUpdaterPlugin;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePlugin;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePlugin;
 
 
 public class GpuResourcePlugin implements ResourcePlugin {
 public class GpuResourcePlugin implements ResourcePlugin {
   private ResourceHandler gpuResourceHandler = null;
   private ResourceHandler gpuResourceHandler = null;
   private GpuNodeResourceUpdateHandler resourceDiscoverHandler = null;
   private GpuNodeResourceUpdateHandler resourceDiscoverHandler = null;
+  private DockerCommandPlugin dockerCommandPlugin = null;
 
 
   @Override
   @Override
   public synchronized void initialize(Context context) throws YarnException {
   public synchronized void initialize(Context context) throws YarnException {
     resourceDiscoverHandler = new GpuNodeResourceUpdateHandler();
     resourceDiscoverHandler = new GpuNodeResourceUpdateHandler();
     GpuDiscoverer.getInstance().initialize(context.getConf());
     GpuDiscoverer.getInstance().initialize(context.getConf());
+    dockerCommandPlugin =
+        GpuDockerCommandPluginFactory.createGpuDockerCommandPlugin(
+            context.getConf());
   }
   }
 
 
   @Override
   @Override
@@ -58,4 +63,8 @@ public class GpuResourcePlugin implements ResourcePlugin {
   public void cleanup() throws YarnException {
   public void cleanup() throws YarnException {
     // Do nothing.
     // Do nothing.
   }
   }
+
+  public DockerCommandPlugin getDockerCommandPluginInstance() {
+    return dockerCommandPlugin;
+  }
 }
 }

+ 319 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV1CommandPlugin.java

@@ -0,0 +1,319 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.io.StringWriter;
+import java.net.URL;
+import java.net.URLConnection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand.VOLUME_NAME_PATTERN;
+
+/**
+ * Implementation to use nvidia-docker v1 as GPU docker command plugin.
+ */
+public class NvidiaDockerV1CommandPlugin implements DockerCommandPlugin {
+  final static Log LOG = LogFactory.getLog(NvidiaDockerV1CommandPlugin.class);
+
+  private Configuration conf;
+  // Lazily populated on the first GPU container launch: maps a docker CLI
+  // option name (e.g. "--device") to the set of values advertised by the
+  // nvidia-docker-plugin endpoint. Remains null until init() succeeds, and
+  // doubles as the "already initialized" flag in initializeWhenGpuRequested.
+  private Map<String, Set<String>> additionalCommands = null;
+  // Docker volume driver to use for the driver volume; overwritten when the
+  // plugin output contains a --volume-driver option.
+  private String volumeDriver = "local";
+
+  // Known option
+  private String DEVICE_OPTION = "--device";
+  private String VOLUME_DRIVER_OPTION = "--volume-driver";
+  private String MOUNT_RO_OPTION = "--volume";
+
+  public NvidiaDockerV1CommandPlugin(Configuration conf) {
+    this.conf = conf;
+  }
+
+  // Get value from key=value
+  // Throw exception if '=' not found
+  private String getValue(String input) throws IllegalArgumentException {
+    int index = input.indexOf('=');
+    if (index < 0) {
+      throw new IllegalArgumentException(
+          "Failed to locate '=' from input=" + input);
+    }
+    return input.substring(index + 1);
+  }
+
+  // Record one parsed option value under its option name, creating the
+  // map and per-option set on first use.
+  private void addToCommand(String key, String value) {
+    if (additionalCommands == null) {
+      additionalCommands = new HashMap<>();
+    }
+    if (!additionalCommands.containsKey(key)) {
+      additionalCommands.put(key, new HashSet<>());
+    }
+    additionalCommands.get(key).add(value);
+  }
+
+  // Query the local nvidia-docker-plugin HTTP endpoint for the docker CLI
+  // options required to run GPU containers, and parse them into
+  // additionalCommands / volumeDriver.
+  // NOTE(review): when the endpoint config is empty this returns early and
+  // leaves additionalCommands null; a subsequent GPU launch would then NPE
+  // in updateDockerRunCommand / getCreateDockerVolumeCommand — confirm
+  // whether an empty endpoint is expected to coexist with GPU requests.
+  private void init() throws ContainerExecutionException {
+    String endpoint = conf.get(
+        YarnConfiguration.NVIDIA_DOCKER_PLUGIN_V1_ENDPOINT,
+        YarnConfiguration.DEFAULT_NVIDIA_DOCKER_PLUGIN_V1_ENDPOINT);
+    if (null == endpoint || endpoint.isEmpty()) {
+      LOG.info(YarnConfiguration.NVIDIA_DOCKER_PLUGIN_V1_ENDPOINT
+          + " set to empty, skip init ..");
+      return;
+    }
+    String cliOptions;
+    try {
+      // Talk to plugin server and get options
+      URL url = new URL(endpoint);
+      URLConnection uc = url.openConnection();
+      uc.setRequestProperty("X-Requested-With", "Curl");
+
+      StringWriter writer = new StringWriter();
+      IOUtils.copy(uc.getInputStream(), writer, "utf-8");
+      cliOptions = writer.toString();
+
+      LOG.info("Additional docker CLI options from plugin to run GPU "
+          + "containers:" + cliOptions);
+
+      // Parse cli options
+      // Examples like:
+      // --device=/dev/nvidiactl --device=/dev/nvidia-uvm --device=/dev/nvidia0
+      // --volume-driver=nvidia-docker
+      // --volume=nvidia_driver_352.68:/usr/local/nvidia:ro
+
+      for (String str : cliOptions.split(" ")) {
+        str = str.trim();
+        if (str.startsWith(DEVICE_OPTION)) {
+          addToCommand(DEVICE_OPTION, getValue(str));
+        } else if (str.startsWith(VOLUME_DRIVER_OPTION)) {
+          volumeDriver = getValue(str);
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("Found volume-driver:" + volumeDriver);
+          }
+        } else if (str.startsWith(MOUNT_RO_OPTION)) {
+          String mount = getValue(str);
+          // Only read-only mounts are accepted; the ":ro" suffix is then
+          // stripped so only "volume:target" is stored.
+          if (!mount.endsWith(":ro")) {
+            throw new IllegalArgumentException(
+                "Should not have mount other than ro, command=" + str);
+          }
+          addToCommand(MOUNT_RO_OPTION,
+              mount.substring(0, mount.lastIndexOf(':')));
+        } else{
+          throw new IllegalArgumentException("Unsupported option:" + str);
+        }
+      }
+    } catch (RuntimeException e) {
+      // Wrap parse failures (IllegalArgumentException etc.) keeping the cause.
+      LOG.warn(
+          "RuntimeException of " + this.getClass().getSimpleName() + " init:",
+          e);
+      throw new ContainerExecutionException(e);
+    } catch (IOException e) {
+      LOG.warn("IOException of " + this.getClass().getSimpleName() + " init:",
+          e);
+      throw new ContainerExecutionException(e);
+    }
+  }
+
+  // Extract the GPU index [n] from a device name like /dev/nvidia[n].
+  // Returns -1 for controller devices whose suffix is non-numeric
+  // (e.g. /dev/nvidiactl, /dev/nvidia-uvm).
+  // NOTE(review): a name ending exactly in "nvidia" (empty suffix) skips the
+  // digit loop and reaches Integer.parseInt(""), which throws
+  // NumberFormatException — confirm such device names cannot occur.
+  private int getGpuIndexFromDeviceName(String device) {
+    final String NVIDIA = "nvidia";
+    int idx = device.lastIndexOf(NVIDIA);
+    if (idx < 0) {
+      return -1;
+    }
+    // Get last part
+    String str = device.substring(idx + NVIDIA.length());
+    for (int i = 0; i < str.length(); i++) {
+      if (!Character.isDigit(str.charAt(i))) {
+        return -1;
+      }
+    }
+    return Integer.parseInt(str);
+  }
+
+  // Collect the GpuDevices assigned to this container from its resource
+  // mappings (populated at allocation / recovered from the state store).
+  // Returns an empty set when nothing was assigned.
+  private Set<GpuDevice> getAssignedGpus(Container container) {
+    ResourceMappings resourceMappings = container.getResourceMappings();
+
+    // Copy of assigned Resources
+    Set<GpuDevice> assignedResources = null;
+    if (resourceMappings != null) {
+      assignedResources = new HashSet<>();
+      for (Serializable s : resourceMappings.getAssignedResources(
+          ResourceInformation.GPU_URI)) {
+        assignedResources.add((GpuDevice) s);
+      }
+    }
+
+    if (assignedResources == null || assignedResources.isEmpty()) {
+      // When no GPU resource assigned, don't need to update docker command.
+      return Collections.emptySet();
+    }
+
+    return assignedResources;
+  }
+
+  // True when the container's resource request includes at least one GPU.
+  @VisibleForTesting
+  protected boolean requestsGpu(Container container) {
+    return GpuResourceAllocator.getRequestedGpus(container.getResource()) > 0;
+  }
+
+  /**
+   * Do initialize when GPU requested
+   * @param container nmContainer
+   * @return if #GPU-requested > 0
+   * @throws ContainerExecutionException when any issue happens
+   */
+  private boolean initializeWhenGpuRequested(Container container)
+      throws ContainerExecutionException {
+    if (!requestsGpu(container)) {
+      return false;
+    }
+
+    // Do lazy initialization of gpu-docker plugin
+    if (additionalCommands == null) {
+      init();
+    }
+
+    return true;
+  }
+
+  // Rewrites the docker run command for a GPU container: controller devices
+  // are always passed through, GPU card devices are filtered down to those
+  // actually assigned to this container, and read-only driver volume mounts
+  // are added.
+  @Override
+  public synchronized void updateDockerRunCommand(
+      DockerRunCommand dockerRunCommand, Container container)
+      throws ContainerExecutionException {
+    if (!initializeWhenGpuRequested(container)) {
+      return;
+    }
+
+    Set<GpuDevice> assignedResources = getAssignedGpus(container);
+    if (assignedResources == null || assignedResources.isEmpty()) {
+      return;
+    }
+
+    // Write to dockerRunCommand
+    for (Map.Entry<String, Set<String>> option : additionalCommands
+        .entrySet()) {
+      String key = option.getKey();
+      Set<String> values = option.getValue();
+      if (key.equals(DEVICE_OPTION)) {
+        int foundGpuDevices = 0;
+        for (String deviceName : values) {
+          // When specified is a GPU card (device name like /dev/nvidia[n]
+          // Get index of the GPU (which is [n]).
+          Integer gpuIdx = getGpuIndexFromDeviceName(deviceName);
+          if (gpuIdx >= 0) {
+            // Use assignedResources to filter --device given by
+            // nvidia-docker-plugin.
+            for (GpuDevice gpuDevice : assignedResources) {
+              if (gpuDevice.getIndex() == gpuIdx) {
+                foundGpuDevices++;
+                dockerRunCommand.addDevice(deviceName, deviceName);
+              }
+            }
+          } else{
+            // When gpuIdx < 0, it is a controller device (such as
+            // /dev/nvidiactl). In this case, add device directly.
+            dockerRunCommand.addDevice(deviceName, deviceName);
+          }
+        }
+
+        // Cannot get all assigned Gpu devices from docker plugin output
+        if (foundGpuDevices < assignedResources.size()) {
+          throw new ContainerExecutionException(
+              "Cannot get all assigned Gpu devices from docker plugin output");
+        }
+      } else if (key.equals(MOUNT_RO_OPTION)) {
+        // Stored values are "volume:target" (":ro" stripped in init());
+        // re-add read-only semantics via addReadOnlyMountLocation.
+        for (String value : values) {
+          int idx = value.indexOf(':');
+          String source = value.substring(0, idx);
+          String target = value.substring(idx + 1);
+          dockerRunCommand.addReadOnlyMountLocation(source, target, true);
+        }
+      } else{
+        throw new ContainerExecutionException("Unsupported option:" + key);
+      }
+    }
+  }
+
+  // Builds the "docker volume create" command for the nvidia driver volume
+  // advertised by the plugin; returns null when no GPU is requested or no
+  // named volume is found in the plugin output.
+  // NOTE(review): additionalCommands.get(MOUNT_RO_OPTION) is null when the
+  // plugin advertised no --volume options, which would NPE in the loop
+  // below — confirm the plugin output always includes a volume mount.
+  @Override
+  public DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
+      throws ContainerExecutionException {
+    if (!initializeWhenGpuRequested(container)) {
+      return null;
+    }
+
+    String newVolumeName = null;
+
+    // Get volume name
+    Set<String> mounts = additionalCommands.get(MOUNT_RO_OPTION);
+    for (String mount : mounts) {
+      int idx = mount.indexOf(':');
+      if (idx >= 0) {
+        String mountSource = mount.substring(0, idx);
+        if (VOLUME_NAME_PATTERN.matcher(mountSource).matches()) {
+          // This is a valid named volume
+          newVolumeName = mountSource;
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("Found volume name for GPU:" + newVolumeName);
+          }
+          break;
+        } else{
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("Failed to match " + mountSource
+                + " to named-volume regex pattern");
+          }
+        }
+      }
+    }
+
+    if (newVolumeName != null) {
+      DockerVolumeCommand command = new DockerVolumeCommand(
+          DockerVolumeCommand.VOLUME_CREATE_COMMAND);
+      command.setDriverName(volumeDriver);
+      command.setVolumeName(newVolumeName);
+      return command;
+    }
+
+    return null;
+  }
+
+  @Override
+  public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
+      throws ContainerExecutionException {
+    // No cleanup needed.
+    return null;
+  }
+}

+ 33 - 29
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMLeveldbStateStoreService.java

@@ -18,27 +18,9 @@
 
 
 package org.apache.hadoop.yarn.server.nodemanager.recovery;
 package org.apache.hadoop.yarn.server.nodemanager.recovery;
 
 
-import static org.fusesource.leveldbjni.JniDBFactory.asString;
-import static org.fusesource.leveldbjni.JniDBFactory.bytes;
-
-import org.apache.hadoop.yarn.api.records.Token;
-import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.Timer;
-import java.util.TimerTask;
-import java.util.Set;
-
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.ListMultimap;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.fs.Path;
@@ -50,9 +32,11 @@ import org.apache.hadoop.yarn.api.protocolrecords.impl.pb.StartContainerRequestP
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.Token;
 import org.apache.hadoop.yarn.api.records.impl.pb.ResourcePBImpl;
 import org.apache.hadoop.yarn.api.records.impl.pb.ResourcePBImpl;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.proto.YarnProtos.LocalResourceProto;
 import org.apache.hadoop.yarn.proto.YarnProtos.LocalResourceProto;
+import org.apache.hadoop.yarn.proto.YarnSecurityTokenProtos.ContainerTokenIdentifierProto;
 import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.MasterKeyProto;
 import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.MasterKeyProto;
 import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto;
 import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.ContainerManagerApplicationProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.ContainerManagerApplicationProto;
@@ -60,9 +44,10 @@ import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.Deletion
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LocalizedResourceProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LocalizedResourceProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDeleterProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDeleterProto;
 import org.apache.hadoop.yarn.proto.YarnServiceProtos.StartContainerRequestProto;
 import org.apache.hadoop.yarn.proto.YarnServiceProtos.StartContainerRequestProto;
-import org.apache.hadoop.yarn.proto.YarnSecurityTokenProtos.ContainerTokenIdentifierProto;
+import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.impl.pb.MasterKeyPBImpl;
 import org.apache.hadoop.yarn.server.api.records.impl.pb.MasterKeyPBImpl;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 import org.apache.hadoop.yarn.server.records.Version;
 import org.apache.hadoop.yarn.server.records.Version;
 import org.apache.hadoop.yarn.server.records.impl.pb.VersionPBImpl;
 import org.apache.hadoop.yarn.server.records.impl.pb.VersionPBImpl;
@@ -74,10 +59,24 @@ import org.iq80.leveldb.DB;
 import org.iq80.leveldb.DBException;
 import org.iq80.leveldb.DBException;
 import org.iq80.leveldb.Options;
 import org.iq80.leveldb.Options;
 import org.iq80.leveldb.WriteBatch;
 import org.iq80.leveldb.WriteBatch;
+import org.slf4j.LoggerFactory;
 
 
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.ArrayListMultimap;
-import com.google.common.collect.ListMultimap;
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.Timer;
+import java.util.TimerTask;
+
+import static org.fusesource.leveldbjni.JniDBFactory.asString;
+import static org.fusesource.leveldbjni.JniDBFactory.bytes;
 
 
 public class NMLeveldbStateStoreService extends NMStateStoreService {
 public class NMLeveldbStateStoreService extends NMStateStoreService {
 
 
@@ -1173,15 +1172,17 @@ public class NMLeveldbStateStoreService extends NMStateStoreService {
   }
   }
 
 
   @Override
   @Override
-  public void storeAssignedResources(ContainerId containerId,
+  public void storeAssignedResources(Container container,
       String resourceType, List<Serializable> assignedResources)
       String resourceType, List<Serializable> assignedResources)
       throws IOException {
       throws IOException {
     if (LOG.isDebugEnabled()) {
     if (LOG.isDebugEnabled()) {
-      LOG.debug("storeAssignedResources: containerId=" + containerId
-          + ", assignedResources=" + StringUtils.join(",", assignedResources));
+      LOG.debug(
+          "storeAssignedResources: containerId=" + container.getContainerId()
+              + ", assignedResources=" + StringUtils
+              .join(",", assignedResources));
     }
     }
 
 
-    String keyResChng = CONTAINERS_KEY_PREFIX + containerId.toString()
+    String keyResChng = CONTAINERS_KEY_PREFIX + container.getContainerId().toString()
         + CONTAINER_ASSIGNED_RESOURCES_KEY_SUFFIX + resourceType;
         + CONTAINER_ASSIGNED_RESOURCES_KEY_SUFFIX + resourceType;
     try {
     try {
       WriteBatch batch = db.createWriteBatch();
       WriteBatch batch = db.createWriteBatch();
@@ -1199,6 +1200,9 @@ public class NMLeveldbStateStoreService extends NMStateStoreService {
     } catch (DBException e) {
     } catch (DBException e) {
       throw new IOException(e);
       throw new IOException(e);
     }
     }
+
+    // update container resource mapping.
+    updateContainerResourceMapping(container, resourceType, assignedResources);
   }
   }
 
 
   @SuppressWarnings("deprecation")
   @SuppressWarnings("deprecation")

+ 2 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMNullStateStoreService.java

@@ -35,6 +35,7 @@ import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.Localize
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDeleterProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDeleterProto;
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 
 
 // The state store to use when state isn't being stored
 // The state store to use when state isn't being stored
 public class NMNullStateStoreService extends NMStateStoreService {
 public class NMNullStateStoreService extends NMStateStoreService {
@@ -268,7 +269,7 @@ public class NMNullStateStoreService extends NMStateStoreService {
   }
   }
 
 
   @Override
   @Override
-  public void storeAssignedResources(ContainerId containerId,
+  public void storeAssignedResources(Container container,
       String resourceType, List<Serializable> assignedResources)
       String resourceType, List<Serializable> assignedResources)
       throws IOException {
       throws IOException {
   }
   }

+ 13 - 2
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMStateStoreService.java

@@ -44,6 +44,7 @@ import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.Localize
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDeleterProto;
 import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDeleterProto;
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 
 
 @Private
 @Private
@@ -731,12 +732,12 @@ public abstract class NMStateStoreService extends AbstractService {
   /**
   /**
    * Store the assigned resources to a container.
    * Store the assigned resources to a container.
    *
    *
-   * @param containerId Container Id
+   * @param container NMContainer
    * @param resourceType Resource Type
    * @param resourceType Resource Type
    * @param assignedResources Assigned resources
    * @param assignedResources Assigned resources
    * @throws IOException if fails
    * @throws IOException if fails
    */
    */
-  public abstract void storeAssignedResources(ContainerId containerId,
+  public abstract void storeAssignedResources(Container container,
       String resourceType, List<Serializable> assignedResources)
       String resourceType, List<Serializable> assignedResources)
       throws IOException;
       throws IOException;
 
 
@@ -745,4 +746,14 @@ public abstract class NMStateStoreService extends AbstractService {
   protected abstract void startStorage() throws IOException;
   protected abstract void startStorage() throws IOException;
 
 
   protected abstract void closeStorage() throws IOException;
   protected abstract void closeStorage() throws IOException;
+
+  protected void updateContainerResourceMapping(Container container,
+      String resourceType, List<Serializable> assignedResources) {
+    // Update Container#getResourceMapping.
+    ResourceMappings.AssignedResources newAssigned =
+        new ResourceMappings.AssignedResources();
+    newAssigned.updateAssignedResources(assignedResources);
+    container.getResourceMappings().addAssignedResources(resourceType,
+        newAssigned);
+  }
 }
 }

+ 130 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/utils/docker-util.c

@@ -159,6 +159,11 @@ static int add_docker_config_param(const struct configuration *command_config, c
   return add_param_to_command(command_config, "docker-config", "--config=", 1, out, outlen);
   return add_param_to_command(command_config, "docker-config", "--config=", 1, out, outlen);
 }
 }
 
 
+static int validate_volume_name(const char *volume_name) {
+  const char *regex_str = "^[a-zA-Z0-9]([a-zA-Z0-9_.-]*)$";
+  return execute_regex_match(regex_str, volume_name);
+}
+
 static int validate_container_name(const char *container_name) {
 static int validate_container_name(const char *container_name) {
   const char *CONTAINER_NAME_PREFIX = "container_";
   const char *CONTAINER_NAME_PREFIX = "container_";
   if (0 == strncmp(container_name, CONTAINER_NAME_PREFIX, strlen(CONTAINER_NAME_PREFIX))) {
   if (0 == strncmp(container_name, CONTAINER_NAME_PREFIX, strlen(CONTAINER_NAME_PREFIX))) {
@@ -206,6 +211,12 @@ const char *get_docker_error_message(const int error_code) {
       return "Mount access error";
       return "Mount access error";
     case INVALID_DOCKER_DEVICE:
     case INVALID_DOCKER_DEVICE:
       return "Invalid docker device";
       return "Invalid docker device";
+    case INVALID_DOCKER_VOLUME_DRIVER:
+      return "Invalid docker volume-driver";
+    case INVALID_DOCKER_VOLUME_NAME:
+      return "Invalid docker volume name";
+    case INVALID_DOCKER_VOLUME_COMMAND:
+      return "Invalid docker volume command";
     default:
     default:
       return "Unknown error";
       return "Unknown error";
   }
   }
@@ -252,11 +263,125 @@ int get_docker_command(const char *command_file, const struct configuration *con
     return get_docker_run_command(command_file, conf, out, outlen);
     return get_docker_run_command(command_file, conf, out, outlen);
   } else if (strcmp(DOCKER_STOP_COMMAND, command) == 0) {
   } else if (strcmp(DOCKER_STOP_COMMAND, command) == 0) {
     return get_docker_stop_command(command_file, conf, out, outlen);
     return get_docker_stop_command(command_file, conf, out, outlen);
+  } else if (strcmp(DOCKER_VOLUME_COMMAND, command) == 0) {
+    return get_docker_volume_command(command_file, conf, out, outlen);
   } else {
   } else {
     return UNKNOWN_DOCKER_COMMAND;
     return UNKNOWN_DOCKER_COMMAND;
   }
   }
 }
 }
 
 
+// check if a key is permitted in the configuration
+// return 1 if permitted
+static int value_permitted(const struct configuration* executor_cfg,
+                           const char* key, const char* value) {
+  char **permitted_values = get_configuration_values_delimiter(key,
+    CONTAINER_EXECUTOR_CFG_DOCKER_SECTION, executor_cfg, ",");
+  if (!permitted_values) {
+    return 0;
+  }
+
+  char** permitted = permitted_values;
+  int found = 0;
+
+  while (*permitted) {
+    if (0 == strncmp(*permitted, value, 1024)) {
+      found = 1;
+      break;
+    }
+    permitted++;
+  }
+
+  free_values(permitted_values);
+
+  return found;
+}
+
+int get_docker_volume_command(const char *command_file, const struct configuration *conf, char *out,
+                               const size_t outlen) {
+  int ret = 0;
+  char *driver = NULL, *volume_name = NULL, *sub_command = NULL;
+  struct configuration command_config = {0, NULL};
+  ret = read_and_verify_command_file(command_file, DOCKER_VOLUME_COMMAND, &command_config);
+  if (ret != 0) {
+    return ret;
+  }
+  sub_command = get_configuration_value("sub-command", DOCKER_COMMAND_FILE_SECTION, &command_config);
+  if (sub_command == NULL || 0 != strcmp(sub_command, "create")) {
+    fprintf(ERRORFILE, "\"create\" is the only acceptable sub-command of volume.\n");
+    ret = INVALID_DOCKER_VOLUME_COMMAND;
+    goto cleanup;
+  }
+
+  volume_name = get_configuration_value("volume", DOCKER_COMMAND_FILE_SECTION, &command_config);
+  if (volume_name == NULL || validate_volume_name(volume_name) != 0) {
+    fprintf(ERRORFILE, "%s is not a valid volume name.\n", volume_name);
+    ret = INVALID_DOCKER_VOLUME_NAME;
+    goto cleanup;
+  }
+
+  driver = get_configuration_value("driver", DOCKER_COMMAND_FILE_SECTION, &command_config);
+  if (driver == NULL) {
+    ret = INVALID_DOCKER_VOLUME_DRIVER;
+    goto cleanup;
+  }
+
+  memset(out, 0, outlen);
+
+  ret = add_docker_config_param(&command_config, out, outlen);
+  if (ret != 0) {
+    ret = BUFFER_TOO_SMALL;
+    goto cleanup;
+  }
+
+  ret = add_to_buffer(out, outlen, DOCKER_VOLUME_COMMAND);
+  if (ret != 0) {
+    goto cleanup;
+  }
+
+  ret = add_to_buffer(out, outlen, " create");
+  if (ret != 0) {
+    goto cleanup;
+  }
+
+  ret = add_to_buffer(out, outlen, " --name=");
+  if (ret != 0) {
+    goto cleanup;
+  }
+
+  ret = add_to_buffer(out, outlen, volume_name);
+  if (ret != 0) {
+    goto cleanup;
+  }
+
+  if (!value_permitted(conf, "docker.allowed.volume-drivers", driver)) {
+    fprintf(ERRORFILE, "%s is not permitted docker.allowed.volume-drivers\n",
+      driver);
+    ret = INVALID_DOCKER_VOLUME_DRIVER;
+    goto cleanup;
+  }
+
+  ret = add_to_buffer(out, outlen, " --driver=");
+  if (ret != 0) {
+    goto cleanup;
+  }
+
+  ret = add_to_buffer(out, outlen, driver);
+  if (ret != 0) {
+    goto cleanup;
+  }
+
+cleanup:
+  free(driver);
+  free(volume_name);
+  free(sub_command);
+
+  // clean up out buffer
+  if (ret != 0) {
+    out[0] = 0;
+  }
+  return ret;
+}
+
 int get_docker_inspect_command(const char *command_file, const struct configuration *conf, char *out,
 int get_docker_inspect_command(const char *command_file, const struct configuration *conf, char *out,
                                const size_t outlen) {
                                const size_t outlen) {
   const char *valid_format_strings[] = { "{{.State.Status}}",
   const char *valid_format_strings[] = { "{{.State.Status}}",
@@ -623,6 +748,11 @@ static char* normalize_mount(const char* mount) {
   }
   }
   real_mount = realpath(mount, NULL);
   real_mount = realpath(mount, NULL);
   if (real_mount == NULL) {
   if (real_mount == NULL) {
+    // If mount is a valid named volume, just return it and let docker decide
+    if (validate_volume_name(mount) == 0) {
+      return strdup(mount);
+    }
+
     fprintf(ERRORFILE, "Could not determine real path of mount '%s'\n", mount);
     fprintf(ERRORFILE, "Could not determine real path of mount '%s'\n", mount);
     free(real_mount);
     free(real_mount);
     return NULL;
     return NULL;

+ 17 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/utils/docker-util.h

@@ -30,6 +30,7 @@
 #define DOCKER_RM_COMMAND "rm"
 #define DOCKER_RM_COMMAND "rm"
 #define DOCKER_RUN_COMMAND "run"
 #define DOCKER_RUN_COMMAND "run"
 #define DOCKER_STOP_COMMAND "stop"
 #define DOCKER_STOP_COMMAND "stop"
+#define DOCKER_VOLUME_COMMAND "volume"
 
 
 
 
 enum docker_error_codes {
 enum docker_error_codes {
@@ -49,7 +50,10 @@ enum docker_error_codes {
     INVALID_DOCKER_RW_MOUNT,
     INVALID_DOCKER_RW_MOUNT,
     MOUNT_ACCESS_ERROR,
     MOUNT_ACCESS_ERROR,
     INVALID_DOCKER_DEVICE,
     INVALID_DOCKER_DEVICE,
-    INVALID_DOCKER_STOP_COMMAND
+    INVALID_DOCKER_STOP_COMMAND,
+    INVALID_DOCKER_VOLUME_DRIVER,
+    INVALID_DOCKER_VOLUME_NAME,
+    INVALID_DOCKER_VOLUME_COMMAND
 };
 };
 
 
 /**
 /**
@@ -130,6 +134,18 @@ int get_docker_run_command(const char* command_file, const struct configuration*
  */
  */
 int get_docker_stop_command(const char* command_file, const struct configuration* conf, char *out, const size_t outlen);
 int get_docker_stop_command(const char* command_file, const struct configuration* conf, char *out, const size_t outlen);
 
 
+/**
+ * Get the Docker volume command line string. The function will verify that the
+ * params file is meant for the volume command.
+ * @param command_file File containing the params for the Docker volume command
+ * @param conf Configuration struct containing the container-executor.cfg details
+ * @param out Buffer to fill with the volume command
+ * @param outlen Size of the output buffer
+ * @return Return code with 0 indicating success and non-zero codes indicating error
+ */
+int get_docker_volume_command(const char *command_file, const struct configuration *conf, char *out,
+                               const size_t outlen);
+
 /**
 /**
  * Give an error message for the supplied error code
  * Give an error message for the supplied error code
  * @param error_code the error code
  * @param error_code the error code

+ 42 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/utils/test_docker_util.cc

@@ -1120,4 +1120,46 @@ namespace ContainerExecutor {
     }
     }
   }
   }
 
 
+  TEST_F(TestDockerUtil, test_docker_volume_command) {
+    std::string container_executor_contents = "[docker]\n  docker.allowed.volume-drivers=driver1\n";
+    write_file(container_executor_cfg_file, container_executor_contents);
+    int ret = read_config(container_executor_cfg_file.c_str(), &container_executor_cfg);
+    if (ret != 0) {
+      FAIL();
+    }
+
+    std::vector<std::pair<std::string, std::string> > file_cmd_vec;
+    file_cmd_vec.push_back(std::make_pair<std::string, std::string>(
+        "[docker-command-execution]\n  docker-command=volume\n  sub-command=create\n  volume=volume1 \n driver=driver1",
+        "volume create --name=volume1 --driver=driver1"));
+
+    std::vector<std::pair<std::string, int> > bad_file_cmd_vec;
+
+    // Wrong subcommand
+    bad_file_cmd_vec.push_back(std::make_pair<std::string, int>(
+        "[docker-command-execution]\n  docker-command=volume\n  sub-command=ls\n  volume=volume1 \n driver=driver1",
+        static_cast<int>(INVALID_DOCKER_VOLUME_COMMAND)));
+
+    // Volume not specified
+    bad_file_cmd_vec.push_back(std::make_pair<std::string, int>(
+        "[docker-command-execution]\n  docker-command=volume\n  sub-command=create\n  driver=driver1",
+        static_cast<int>(INVALID_DOCKER_VOLUME_NAME)));
+
+    // Invalid volume name
+    bad_file_cmd_vec.push_back(std::make_pair<std::string, int>(
+        "[docker-command-execution]\n  docker-command=volume\n  sub-command=create\n  volume=/a/b/c \n driver=driver1",
+        static_cast<int>(INVALID_DOCKER_VOLUME_NAME)));
+
+    // Driver not specified
+    bad_file_cmd_vec.push_back(std::make_pair<std::string, int>(
+        "[docker-command-execution]\n  docker-command=volume\n  sub-command=create\n  volume=volume1 \n",
+        static_cast<int>(INVALID_DOCKER_VOLUME_DRIVER)));
+
+    // Invalid driver name
+    bad_file_cmd_vec.push_back(std::make_pair<std::string, int>(
+        "[docker-command-execution]\n  docker-command=volume\n  sub-command=create\n volume=volume1 \n driver=driver2",
+        static_cast<int>(INVALID_DOCKER_VOLUME_DRIVER)));
+
+    run_docker_command_test(file_cmd_vec, bad_file_cmd_vec, get_docker_volume_command);
+  }
 }
 }

+ 3 - 3
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestLinuxContainerExecutorWithMocks.java

@@ -158,7 +158,7 @@ public class TestLinuxContainerExecutorWithMocks {
         mockPrivilegedExec);
         mockPrivilegedExec);
     dirsHandler = new LocalDirsHandlerService();
     dirsHandler = new LocalDirsHandlerService();
     dirsHandler.init(conf);
     dirsHandler.init(conf);
-    linuxContainerRuntime.initialize(conf);
+    linuxContainerRuntime.initialize(conf, null);
     mockExec = new LinuxContainerExecutor(linuxContainerRuntime);
     mockExec = new LinuxContainerExecutor(linuxContainerRuntime);
     mockExec.setConf(conf);
     mockExec.setConf(conf);
     mockExecMockRuntime = new LinuxContainerExecutor(mockLinuxContainerRuntime);
     mockExecMockRuntime = new LinuxContainerExecutor(mockLinuxContainerRuntime);
@@ -315,7 +315,7 @@ public class TestLinuxContainerExecutorWithMocks {
           DefaultLinuxContainerRuntime(PrivilegedOperationExecutor.getInstance(
           DefaultLinuxContainerRuntime(PrivilegedOperationExecutor.getInstance(
               conf));
               conf));
 
 
-      linuxContainerRuntime.initialize(conf);
+      linuxContainerRuntime.initialize(conf, null);
       exec = new LinuxContainerExecutor(linuxContainerRuntime);
       exec = new LinuxContainerExecutor(linuxContainerRuntime);
 
 
       mockExec = spy(exec);
       mockExec = spy(exec);
@@ -545,7 +545,7 @@ public class TestLinuxContainerExecutorWithMocks {
             any(File.class), any(Map.class), anyBoolean(), anyBoolean());
             any(File.class), any(Map.class), anyBoolean(), anyBoolean());
     LinuxContainerRuntime runtime = new DefaultLinuxContainerRuntime(
     LinuxContainerRuntime runtime = new DefaultLinuxContainerRuntime(
         spyPrivilegedExecutor);
         spyPrivilegedExecutor);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
     mockExec = new LinuxContainerExecutor(runtime);
     mockExec = new LinuxContainerExecutor(runtime);
     mockExec.setConf(conf);
     mockExec.setConf(conf);
     LinuxContainerExecutor lce = new LinuxContainerExecutor(runtime) {
     LinuxContainerExecutor lce = new LinuxContainerExecutor(runtime) {

+ 5 - 4
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManagerRecovery.java

@@ -462,16 +462,18 @@ public class TestContainerManagerRecovery extends BaseContainerManagerTest {
 
 
     commonLaunchContainer(appId, cid, cm);
     commonLaunchContainer(appId, cid, cm);
 
 
+    Container nmContainer = context.getContainers().get(cid);
+
     Application app = context.getApplications().get(appId);
     Application app = context.getApplications().get(appId);
     assertNotNull(app);
     assertNotNull(app);
 
 
     // store resource mapping of the container
     // store resource mapping of the container
     List<Serializable> gpuResources = Arrays.asList("1", "2", "3");
     List<Serializable> gpuResources = Arrays.asList("1", "2", "3");
-    stateStore.storeAssignedResources(cid, "gpu", gpuResources);
+    stateStore.storeAssignedResources(nmContainer, "gpu", gpuResources);
     List<Serializable> numaResources = Arrays.asList("numa1");
     List<Serializable> numaResources = Arrays.asList("numa1");
-    stateStore.storeAssignedResources(cid, "numa", numaResources);
+    stateStore.storeAssignedResources(nmContainer, "numa", numaResources);
     List<Serializable> fpgaResources = Arrays.asList("fpga1", "fpga2");
     List<Serializable> fpgaResources = Arrays.asList("fpga1", "fpga2");
-    stateStore.storeAssignedResources(cid, "fpga", fpgaResources);
+    stateStore.storeAssignedResources(nmContainer, "fpga", fpgaResources);
 
 
     cm.stop();
     cm.stop();
     context = createContext(conf, stateStore);
     context = createContext(conf, stateStore);
@@ -483,7 +485,6 @@ public class TestContainerManagerRecovery extends BaseContainerManagerTest {
     app = context.getApplications().get(appId);
     app = context.getApplications().get(appId);
     assertNotNull(app);
     assertNotNull(app);
 
 
-    Container nmContainer = context.getContainers().get(cid);
     Assert.assertNotNull(nmContainer);
     Assert.assertNotNull(nmContainer);
     ResourceMappings resourceMappings = nmContainer.getResourceMappings();
     ResourceMappings resourceMappings = nmContainer.getResourceMappings();
     List<Serializable> assignedResource = resourceMappings
     List<Serializable> assignedResource = resourceMappings

+ 109 - 47
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/TestGpuResourceHandler.java

@@ -20,7 +20,6 @@ package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resourc
 
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.util.StringUtils;
-import org.apache.hadoop.yarn.api.protocolrecords.ResourceTypes;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
@@ -36,15 +35,17 @@ import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileg
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDevice;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDiscoverer;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDiscoverer;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
-import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.apache.hadoop.yarn.util.resource.TestResourceUtils;
 import org.apache.hadoop.yarn.util.resource.TestResourceUtils;
 import org.junit.Assert;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 
 
 import java.io.IOException;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
@@ -90,7 +91,7 @@ public class TestGpuResourceHandler {
   @Test
   @Test
   public void testBootStrap() throws Exception {
   public void testBootStrap() throws Exception {
     Configuration conf = new YarnConfiguration();
     Configuration conf = new YarnConfiguration();
-    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0");
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0");
 
 
     GpuDiscoverer.getInstance().initialize(conf);
     GpuDiscoverer.getInstance().initialize(conf);
 
 
@@ -104,8 +105,8 @@ public class TestGpuResourceHandler {
         .newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
         .newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
   }
   }
 
 
-  private static Container mockContainerWithGpuRequest(int id,
-      int numGpuRequest) {
+  private static Container mockContainerWithGpuRequest(int id, int numGpuRequest,
+      boolean dockerContainerEnabled) {
     Container c = mock(Container.class);
     Container c = mock(Container.class);
     when(c.getContainerId()).thenReturn(getContainerId(id));
     when(c.getContainerId()).thenReturn(getContainerId(id));
 
 
@@ -115,29 +116,46 @@ public class TestGpuResourceHandler {
     res.setResourceValue(ResourceInformation.GPU_URI, numGpuRequest);
     res.setResourceValue(ResourceInformation.GPU_URI, numGpuRequest);
     when(c.getResource()).thenReturn(res);
     when(c.getResource()).thenReturn(res);
     when(c.getResourceMappings()).thenReturn(resMapping);
     when(c.getResourceMappings()).thenReturn(resMapping);
+
+    ContainerLaunchContext clc = mock(ContainerLaunchContext.class);
+    Map<String, String> env = new HashMap<>();
+    if (dockerContainerEnabled) {
+      env.put(ContainerRuntimeConstants.ENV_CONTAINER_TYPE, "docker");
+    }
+    when(clc.getEnvironment()).thenReturn(env);
+    when(c.getLaunchContext()).thenReturn(clc);
     return c;
     return c;
   }
   }
 
 
+  private static Container mockContainerWithGpuRequest(int id,
+      int numGpuRequest) {
+    return mockContainerWithGpuRequest(id, numGpuRequest, false);
+  }
+
   private void verifyDeniedDevices(ContainerId containerId,
   private void verifyDeniedDevices(ContainerId containerId,
-      List<Integer> deniedDevices)
+      List<GpuDevice> deniedDevices)
       throws ResourceHandlerException, PrivilegedOperationException {
       throws ResourceHandlerException, PrivilegedOperationException {
     verify(mockCGroupsHandler, times(1)).createCGroup(
     verify(mockCGroupsHandler, times(1)).createCGroup(
         CGroupsHandler.CGroupController.DEVICES, containerId.toString());
         CGroupsHandler.CGroupController.DEVICES, containerId.toString());
 
 
     if (null != deniedDevices && !deniedDevices.isEmpty()) {
     if (null != deniedDevices && !deniedDevices.isEmpty()) {
+      List<Integer> deniedDevicesMinorNumber = new ArrayList<>();
+      for (GpuDevice deniedDevice : deniedDevices) {
+        deniedDevicesMinorNumber.add(deniedDevice.getMinorNumber());
+      }
       verify(mockPrivilegedExecutor, times(1)).executePrivilegedOperation(
       verify(mockPrivilegedExecutor, times(1)).executePrivilegedOperation(
           new PrivilegedOperation(PrivilegedOperation.OperationType.GPU, Arrays
           new PrivilegedOperation(PrivilegedOperation.OperationType.GPU, Arrays
               .asList(GpuResourceHandlerImpl.CONTAINER_ID_CLI_OPTION,
               .asList(GpuResourceHandlerImpl.CONTAINER_ID_CLI_OPTION,
                   containerId.toString(),
                   containerId.toString(),
                   GpuResourceHandlerImpl.EXCLUDED_GPUS_CLI_OPTION,
                   GpuResourceHandlerImpl.EXCLUDED_GPUS_CLI_OPTION,
-                  StringUtils.join(",", deniedDevices))), true);
+                  StringUtils.join(",", deniedDevicesMinorNumber))), true);
     }
     }
   }
   }
 
 
-  @Test
-  public void testAllocation() throws Exception {
+  private void commonTestAllocation(boolean dockerContainerEnabled)
+      throws Exception {
     Configuration conf = new YarnConfiguration();
     Configuration conf = new YarnConfiguration();
-    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0,1,3,4");
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
     GpuDiscoverer.getInstance().initialize(conf);
     GpuDiscoverer.getInstance().initialize(conf);
 
 
     gpuResourceHandler.bootstrap(conf);
     gpuResourceHandler.bootstrap(conf);
@@ -145,31 +163,52 @@ public class TestGpuResourceHandler {
         gpuResourceHandler.getGpuAllocator().getAvailableGpus());
         gpuResourceHandler.getGpuAllocator().getAvailableGpus());
 
 
     /* Start container 1, asks 3 containers */
     /* Start container 1, asks 3 containers */
-    gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 3));
+    gpuResourceHandler.preStart(
+        mockContainerWithGpuRequest(1, 3, dockerContainerEnabled));
 
 
     // Only device=4 will be blocked.
     // Only device=4 will be blocked.
-    verifyDeniedDevices(getContainerId(1), Arrays.asList(4));
+    if (dockerContainerEnabled) {
+      verifyDeniedDevices(getContainerId(1), Collections.emptyList());
+    } else{
+      verifyDeniedDevices(getContainerId(1), Arrays.asList(new GpuDevice(3,4)));
+    }
 
 
     /* Start container 2, asks 2 containers. Excepted to fail */
     /* Start container 2, asks 2 containers. Excepted to fail */
     boolean failedToAllocate = false;
     boolean failedToAllocate = false;
     try {
     try {
-      gpuResourceHandler.preStart(mockContainerWithGpuRequest(2, 2));
+      gpuResourceHandler.preStart(
+          mockContainerWithGpuRequest(2, 2, dockerContainerEnabled));
     } catch (ResourceHandlerException e) {
     } catch (ResourceHandlerException e) {
       failedToAllocate = true;
       failedToAllocate = true;
     }
     }
     Assert.assertTrue(failedToAllocate);
     Assert.assertTrue(failedToAllocate);
 
 
     /* Start container 3, ask 1 container, succeeded */
     /* Start container 3, ask 1 container, succeeded */
-    gpuResourceHandler.preStart(mockContainerWithGpuRequest(3, 1));
+    gpuResourceHandler.preStart(
+        mockContainerWithGpuRequest(3, 1, dockerContainerEnabled));
 
 
     // devices = 0/1/3 will be blocked
     // devices = 0/1/3 will be blocked
-    verifyDeniedDevices(getContainerId(3), Arrays.asList(0, 1, 3));
+    if (dockerContainerEnabled) {
+      verifyDeniedDevices(getContainerId(3), Collections.emptyList());
+    } else {
+      verifyDeniedDevices(getContainerId(3), Arrays
+          .asList(new GpuDevice(0, 0), new GpuDevice(1, 1),
+              new GpuDevice(2, 3)));
+    }
 
 
-    /* Start container 4, ask 0 container, succeeded */
-    gpuResourceHandler.preStart(mockContainerWithGpuRequest(4, 0));
 
 
-    // All devices will be blocked
-    verifyDeniedDevices(getContainerId(4), Arrays.asList(0, 1, 3, 4));
+    /* Start container 4, ask 0 container, succeeded */
+    gpuResourceHandler.preStart(
+        mockContainerWithGpuRequest(4, 0, dockerContainerEnabled));
+
+    if (dockerContainerEnabled) {
+      verifyDeniedDevices(getContainerId(4), Collections.emptyList());
+    } else{
+      // All devices will be blocked
+      verifyDeniedDevices(getContainerId(4), Arrays
+          .asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3),
+              new GpuDevice(3, 4)));
+    }
 
 
     /* Release container-1, expect cgroups deleted */
     /* Release container-1, expect cgroups deleted */
     gpuResourceHandler.postComplete(getContainerId(1));
     gpuResourceHandler.postComplete(getContainerId(1));
@@ -188,12 +227,24 @@ public class TestGpuResourceHandler {
         gpuResourceHandler.getGpuAllocator().getAvailableGpus());
         gpuResourceHandler.getGpuAllocator().getAvailableGpus());
   }
   }
 
 
+  @Test
+  public void testAllocationWhenDockerContainerEnabled() throws Exception {
+    // When docker container is enabled, no devices should be written to
+    // devices.deny.
+    commonTestAllocation(true);
+  }
+
+  @Test
+  public void testAllocation() throws Exception {
+    commonTestAllocation(false);
+  }
+
   @SuppressWarnings("unchecked")
   @SuppressWarnings("unchecked")
   @Test
   @Test
   public void testAssignedGpuWillBeCleanedupWhenStoreOpFails()
   public void testAssignedGpuWillBeCleanedupWhenStoreOpFails()
       throws Exception {
       throws Exception {
     Configuration conf = new YarnConfiguration();
     Configuration conf = new YarnConfiguration();
-    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0,1,3,4");
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
     GpuDiscoverer.getInstance().initialize(conf);
     GpuDiscoverer.getInstance().initialize(conf);
 
 
     gpuResourceHandler.bootstrap(conf);
     gpuResourceHandler.bootstrap(conf);
@@ -202,7 +253,7 @@ public class TestGpuResourceHandler {
 
 
     doThrow(new IOException("Exception ...")).when(mockNMStateStore)
     doThrow(new IOException("Exception ...")).when(mockNMStateStore)
         .storeAssignedResources(
         .storeAssignedResources(
-        any(ContainerId.class), anyString(), anyList());
+        any(Container.class), anyString(), anyList());
 
 
     boolean exception = false;
     boolean exception = false;
     /* Start container 1, asks 3 containers */
     /* Start container 1, asks 3 containers */
@@ -225,9 +276,12 @@ public class TestGpuResourceHandler {
     conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, " ");
     conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, " ");
     GpuDiscoverer.getInstance().initialize(conf);
     GpuDiscoverer.getInstance().initialize(conf);
 
 
-    gpuResourceHandler.bootstrap(conf);
-    Assert.assertEquals(0,
-        gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+    try {
+      gpuResourceHandler.bootstrap(conf);
+      Assert.fail("Should fail because no GPU available");
+    } catch (ResourceHandlerException e) {
+      // Expected because of no resource available
+    }
 
 
     /* Start container 1, asks 0 containers */
     /* Start container 1, asks 0 containers */
     gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 0));
     gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 0));
@@ -254,7 +308,7 @@ public class TestGpuResourceHandler {
   @Test
   @Test
   public void testAllocationStored() throws Exception {
   public void testAllocationStored() throws Exception {
     Configuration conf = new YarnConfiguration();
     Configuration conf = new YarnConfiguration();
-    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0,1,3,4");
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
     GpuDiscoverer.getInstance().initialize(conf);
     GpuDiscoverer.getInstance().initialize(conf);
 
 
     gpuResourceHandler.bootstrap(conf);
     gpuResourceHandler.bootstrap(conf);
@@ -265,33 +319,33 @@ public class TestGpuResourceHandler {
     Container container = mockContainerWithGpuRequest(1, 3);
     Container container = mockContainerWithGpuRequest(1, 3);
     gpuResourceHandler.preStart(container);
     gpuResourceHandler.preStart(container);
 
 
-    verify(mockNMStateStore).storeAssignedResources(getContainerId(1),
-        ResourceInformation.GPU_URI,
-        Arrays.asList("0", "1", "3"));
-
-    Assert.assertEquals(3, container.getResourceMappings()
-        .getAssignedResources(ResourceInformation.GPU_URI).size());
+    verify(mockNMStateStore).storeAssignedResources(container,
+        ResourceInformation.GPU_URI, Arrays
+            .asList(new GpuDevice(0, 0), new GpuDevice(1, 1),
+                new GpuDevice(2, 3)));
 
 
     // Only device=4 will be blocked.
     // Only device=4 will be blocked.
-    verifyDeniedDevices(getContainerId(1), Arrays.asList(4));
+    verifyDeniedDevices(getContainerId(1), Arrays.asList(new GpuDevice(3, 4)));
 
 
     /* Start container 2, ask 0 container, succeeded */
     /* Start container 2, ask 0 container, succeeded */
     container = mockContainerWithGpuRequest(2, 0);
     container = mockContainerWithGpuRequest(2, 0);
     gpuResourceHandler.preStart(container);
     gpuResourceHandler.preStart(container);
 
 
-    verifyDeniedDevices(getContainerId(2), Arrays.asList(0, 1, 3, 4));
+    verifyDeniedDevices(getContainerId(2), Arrays
+        .asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3),
+            new GpuDevice(3, 4)));
     Assert.assertEquals(0, container.getResourceMappings()
     Assert.assertEquals(0, container.getResourceMappings()
         .getAssignedResources(ResourceInformation.GPU_URI).size());
         .getAssignedResources(ResourceInformation.GPU_URI).size());
 
 
     // Store assigned resource will not be invoked.
     // Store assigned resource will not be invoked.
     verify(mockNMStateStore, never()).storeAssignedResources(
     verify(mockNMStateStore, never()).storeAssignedResources(
-        eq(getContainerId(2)), eq(ResourceInformation.GPU_URI), anyList());
+        eq(container), eq(ResourceInformation.GPU_URI), anyList());
   }
   }
 
 
   @Test
   @Test
   public void testRecoverResourceAllocation() throws Exception {
   public void testRecoverResourceAllocation() throws Exception {
     Configuration conf = new YarnConfiguration();
     Configuration conf = new YarnConfiguration();
-    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0,1,3,4");
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
     GpuDiscoverer.getInstance().initialize(conf);
     GpuDiscoverer.getInstance().initialize(conf);
 
 
     gpuResourceHandler.bootstrap(conf);
     gpuResourceHandler.bootstrap(conf);
@@ -302,7 +356,8 @@ public class TestGpuResourceHandler {
     ResourceMappings rmap = new ResourceMappings();
     ResourceMappings rmap = new ResourceMappings();
     ResourceMappings.AssignedResources ar =
     ResourceMappings.AssignedResources ar =
         new ResourceMappings.AssignedResources();
         new ResourceMappings.AssignedResources();
-    ar.updateAssignedResources(Arrays.asList("1", "3"));
+    ar.updateAssignedResources(
+        Arrays.asList(new GpuDevice(1, 1), new GpuDevice(2, 3)));
     rmap.addAssignedResources(ResourceInformation.GPU_URI, ar);
     rmap.addAssignedResources(ResourceInformation.GPU_URI, ar);
     when(nmContainer.getResourceMappings()).thenReturn(rmap);
     when(nmContainer.getResourceMappings()).thenReturn(rmap);
 
 
@@ -312,12 +367,15 @@ public class TestGpuResourceHandler {
     // Reacquire container restore state of GPU Resource Allocator.
     // Reacquire container restore state of GPU Resource Allocator.
     gpuResourceHandler.reacquireContainer(getContainerId(1));
     gpuResourceHandler.reacquireContainer(getContainerId(1));
 
 
-    Map<Integer, ContainerId> deviceAllocationMapping =
+    Map<GpuDevice, ContainerId> deviceAllocationMapping =
         gpuResourceHandler.getGpuAllocator().getDeviceAllocationMapping();
         gpuResourceHandler.getGpuAllocator().getDeviceAllocationMapping();
     Assert.assertEquals(2, deviceAllocationMapping.size());
     Assert.assertEquals(2, deviceAllocationMapping.size());
     Assert.assertTrue(
     Assert.assertTrue(
-        deviceAllocationMapping.keySet().containsAll(Arrays.asList(1, 3)));
-    Assert.assertEquals(deviceAllocationMapping.get(1), getContainerId(1));
+        deviceAllocationMapping.keySet().contains(new GpuDevice(1, 1)));
+    Assert.assertTrue(
+        deviceAllocationMapping.keySet().contains(new GpuDevice(2, 3)));
+    Assert.assertEquals(deviceAllocationMapping.get(new GpuDevice(1, 1)),
+        getContainerId(1));
 
 
     // TEST CASE
     // TEST CASE
     // Try to reacquire a container but requested device is not in allowed list.
     // Try to reacquire a container but requested device is not in allowed list.
@@ -325,7 +383,8 @@ public class TestGpuResourceHandler {
     rmap = new ResourceMappings();
     rmap = new ResourceMappings();
     ar = new ResourceMappings.AssignedResources();
     ar = new ResourceMappings.AssignedResources();
     // id=5 is not in allowed list.
     // id=5 is not in allowed list.
-    ar.updateAssignedResources(Arrays.asList("4", "5"));
+    ar.updateAssignedResources(
+        Arrays.asList(new GpuDevice(3, 4), new GpuDevice(4, 5)));
     rmap.addAssignedResources(ResourceInformation.GPU_URI, ar);
     rmap.addAssignedResources(ResourceInformation.GPU_URI, ar);
     when(nmContainer.getResourceMappings()).thenReturn(rmap);
     when(nmContainer.getResourceMappings()).thenReturn(rmap);
 
 
@@ -345,9 +404,10 @@ public class TestGpuResourceHandler {
     deviceAllocationMapping =
     deviceAllocationMapping =
         gpuResourceHandler.getGpuAllocator().getDeviceAllocationMapping();
         gpuResourceHandler.getGpuAllocator().getDeviceAllocationMapping();
     Assert.assertEquals(2, deviceAllocationMapping.size());
     Assert.assertEquals(2, deviceAllocationMapping.size());
-    Assert.assertTrue(
-        deviceAllocationMapping.keySet().containsAll(Arrays.asList(1, 3)));
-    Assert.assertEquals(deviceAllocationMapping.get(1), getContainerId(1));
+    Assert.assertTrue(deviceAllocationMapping.keySet()
+        .containsAll(Arrays.asList(new GpuDevice(1, 1), new GpuDevice(2, 3))));
+    Assert.assertEquals(deviceAllocationMapping.get(new GpuDevice(1, 1)),
+        getContainerId(1));
 
 
     // TEST CASE
     // TEST CASE
     // Try to reacquire a container but requested device is already assigned.
     // Try to reacquire a container but requested device is already assigned.
@@ -355,7 +415,8 @@ public class TestGpuResourceHandler {
     rmap = new ResourceMappings();
     rmap = new ResourceMappings();
     ar = new ResourceMappings.AssignedResources();
     ar = new ResourceMappings.AssignedResources();
     // id=3 is already assigned
     // id=3 is already assigned
-    ar.updateAssignedResources(Arrays.asList("4", "3"));
+    ar.updateAssignedResources(
+        Arrays.asList(new GpuDevice(3, 4), new GpuDevice(2, 3)));
     rmap.addAssignedResources("gpu", ar);
     rmap.addAssignedResources("gpu", ar);
     when(nmContainer.getResourceMappings()).thenReturn(rmap);
     when(nmContainer.getResourceMappings()).thenReturn(rmap);
 
 
@@ -375,8 +436,9 @@ public class TestGpuResourceHandler {
     deviceAllocationMapping =
     deviceAllocationMapping =
         gpuResourceHandler.getGpuAllocator().getDeviceAllocationMapping();
         gpuResourceHandler.getGpuAllocator().getDeviceAllocationMapping();
     Assert.assertEquals(2, deviceAllocationMapping.size());
     Assert.assertEquals(2, deviceAllocationMapping.size());
-    Assert.assertTrue(
-        deviceAllocationMapping.keySet().containsAll(Arrays.asList(1, 3)));
-    Assert.assertEquals(deviceAllocationMapping.get(1), getContainerId(1));
+    Assert.assertTrue(deviceAllocationMapping.keySet()
+        .containsAll(Arrays.asList(new GpuDevice(1, 1), new GpuDevice(2, 3))));
+    Assert.assertEquals(deviceAllocationMapping.get(new GpuDevice(1, 1)),
+        getContainerId(1));
   }
   }
 }
 }

+ 7 - 7
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/TestDelegatingLinuxContainerRuntime.java

@@ -50,7 +50,7 @@ public class TestDelegatingLinuxContainerRuntime {
         YarnConfiguration.DEFAULT_LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES[0]);
         YarnConfiguration.DEFAULT_LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES[0]);
     System.out.println(conf.get(
     System.out.println(conf.get(
         YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES));
         YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES));
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.DEFAULT));
         LinuxContainerRuntimeConstants.RuntimeType.DEFAULT));
     assertFalse(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertFalse(delegatingLinuxContainerRuntime.isRuntimeAllowed(
@@ -63,7 +63,7 @@ public class TestDelegatingLinuxContainerRuntime {
   public void testIsRuntimeAllowedDocker() throws Exception {
   public void testIsRuntimeAllowedDocker() throws Exception {
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         "docker");
         "docker");
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.DOCKER));
         LinuxContainerRuntimeConstants.RuntimeType.DOCKER));
     assertFalse(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertFalse(delegatingLinuxContainerRuntime.isRuntimeAllowed(
@@ -76,7 +76,7 @@ public class TestDelegatingLinuxContainerRuntime {
   public void testIsRuntimeAllowedJavaSandbox() throws Exception {
   public void testIsRuntimeAllowedJavaSandbox() throws Exception {
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         "javasandbox");
         "javasandbox");
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.JAVASANDBOX));
         LinuxContainerRuntimeConstants.RuntimeType.JAVASANDBOX));
     assertFalse(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertFalse(delegatingLinuxContainerRuntime.isRuntimeAllowed(
@@ -89,7 +89,7 @@ public class TestDelegatingLinuxContainerRuntime {
   public void testIsRuntimeAllowedMultiple() throws Exception {
   public void testIsRuntimeAllowedMultiple() throws Exception {
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         "docker,javasandbox");
         "docker,javasandbox");
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.DOCKER));
         LinuxContainerRuntimeConstants.RuntimeType.DOCKER));
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
@@ -102,7 +102,7 @@ public class TestDelegatingLinuxContainerRuntime {
   public void testIsRuntimeAllowedAll() throws Exception {
   public void testIsRuntimeAllowedAll() throws Exception {
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         "default,docker,javasandbox");
         "default,docker,javasandbox");
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
         LinuxContainerRuntimeConstants.RuntimeType.DEFAULT));
         LinuxContainerRuntimeConstants.RuntimeType.DEFAULT));
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
     assertTrue(delegatingLinuxContainerRuntime.isRuntimeAllowed(
@@ -116,7 +116,7 @@ public class TestDelegatingLinuxContainerRuntime {
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         "default,docker");
         "default,docker");
     conf.set(YarnConfiguration.YARN_CONTAINER_SANDBOX, "permissive");
     conf.set(YarnConfiguration.YARN_CONTAINER_SANDBOX, "permissive");
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     ContainerRuntime runtime =
     ContainerRuntime runtime =
         delegatingLinuxContainerRuntime.pickContainerRuntime(env);
         delegatingLinuxContainerRuntime.pickContainerRuntime(env);
     assertTrue(runtime instanceof DefaultLinuxContainerRuntime);
     assertTrue(runtime instanceof DefaultLinuxContainerRuntime);
@@ -129,7 +129,7 @@ public class TestDelegatingLinuxContainerRuntime {
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
     conf.set(YarnConfiguration.LINUX_CONTAINER_RUNTIME_ALLOWED_RUNTIMES,
         "default,docker");
         "default,docker");
     conf.set(YarnConfiguration.YARN_CONTAINER_SANDBOX, "permissive");
     conf.set(YarnConfiguration.YARN_CONTAINER_SANDBOX, "permissive");
-    delegatingLinuxContainerRuntime.initialize(conf);
+    delegatingLinuxContainerRuntime.initialize(conf, null);
     ContainerRuntime runtime =
     ContainerRuntime runtime =
         delegatingLinuxContainerRuntime.pickContainerRuntime(env);
         delegatingLinuxContainerRuntime.pickContainerRuntime(env);
     assertTrue(runtime instanceof DockerLinuxContainerRuntime);
     assertTrue(runtime instanceof DockerLinuxContainerRuntime);

+ 180 - 24
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/TestDockerContainerRuntime.java

@@ -20,15 +20,18 @@
 
 
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime;
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime;
 
 
+import org.apache.commons.io.IOUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileUtil;
 import org.apache.hadoop.fs.FileUtil;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.registry.client.binding.RegistryPathUtils;
 import org.apache.hadoop.registry.client.binding.RegistryPathUtils;
 import org.apache.hadoop.util.Shell;
 import org.apache.hadoop.util.Shell;
+import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
 import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.server.nodemanager.ContainerExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.ContainerExecutor;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
@@ -36,6 +39,10 @@ import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileg
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerModule;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerModule;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePluginManager;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeContext;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeContext;
@@ -48,22 +55,48 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
 import java.io.File;
 import java.io.File;
+import java.io.FileInputStream;
 import java.io.IOException;
 import java.io.IOException;
 import java.nio.charset.Charset;
 import java.nio.charset.Charset;
 import java.nio.file.Files;
 import java.nio.file.Files;
 import java.nio.file.Paths;
 import java.nio.file.Paths;
-import java.util.*;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
+import java.util.Random;
 import java.util.Set;
 import java.util.Set;
 
 
-import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.*;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.APPID;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.CONTAINER_ID_STR;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.CONTAINER_LOCAL_DIRS;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.CONTAINER_LOG_DIRS;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.CONTAINER_WORK_DIR;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.FILECACHE_DIRS;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.LOCALIZED_RESOURCES;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.LOCAL_DIRS;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.LOG_DIRS;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.NM_PRIVATE_CONTAINER_SCRIPT_PATH;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.NM_PRIVATE_TOKENS_PATH;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.PID;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.PID_FILE_PATH;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.RESOURCES_OPTIONS;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.RUN_AS_USER;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.SIGNAL;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.USER;
+import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.LinuxContainerRuntimeConstants.USER_LOCAL_DIRS;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyBoolean;
+import static org.mockito.Mockito.anyList;
+import static org.mockito.Mockito.anyMap;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 
 public class TestDockerContainerRuntime {
 public class TestDockerContainerRuntime {
   private static final Logger LOG =
   private static final Logger LOG =
@@ -217,7 +250,7 @@ public class TestDockerContainerRuntime {
     return opCaptor.getValue();
     return opCaptor.getValue();
   }
   }
 
 
-  @SuppressWarnings("unchecked")
+    @SuppressWarnings("unchecked")
   private PrivilegedOperation capturePrivilegedOperationAndVerifyArgs()
   private PrivilegedOperation capturePrivilegedOperationAndVerifyArgs()
       throws PrivilegedOperationException {
       throws PrivilegedOperationException {
 
 
@@ -288,7 +321,7 @@ public class TestDockerContainerRuntime {
       IOException {
       IOException {
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
     runtime.launchContainer(builder.build());
     runtime.launchContainer(builder.build());
 
 
     PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
     PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
@@ -343,7 +376,7 @@ public class TestDockerContainerRuntime {
 
 
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
     runtime.launchContainer(builder.build());
     runtime.launchContainer(builder.build());
 
 
     PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
     PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
@@ -425,7 +458,7 @@ public class TestDockerContainerRuntime {
 
 
     DockerLinuxContainerRuntime runtime =
     DockerLinuxContainerRuntime runtime =
         new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
         new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     //invalid default network configuration - sdn2 is included in allowed
     //invalid default network configuration - sdn2 is included in allowed
     // networks
     // networks
@@ -441,7 +474,7 @@ public class TestDockerContainerRuntime {
     try {
     try {
       runtime =
       runtime =
           new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
           new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
-      runtime.initialize(conf);
+      runtime.initialize(conf, null);
       Assert.fail("Invalid default network configuration should did not "
       Assert.fail("Invalid default network configuration should did not "
           + "trigger initialization failure.");
           + "trigger initialization failure.");
     } catch (ContainerExecutionException e) {
     } catch (ContainerExecutionException e) {
@@ -457,7 +490,7 @@ public class TestDockerContainerRuntime {
         validDefaultNetwork);
         validDefaultNetwork);
     runtime =
     runtime =
         new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
         new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
   }
   }
 
 
   @Test
   @Test
@@ -467,7 +500,7 @@ public class TestDockerContainerRuntime {
       PrivilegedOperationException {
       PrivilegedOperationException {
     DockerLinuxContainerRuntime runtime =
     DockerLinuxContainerRuntime runtime =
         new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
         new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     Random randEngine = new Random();
     Random randEngine = new Random();
     String disallowedNetwork = "sdn" + Integer.toString(randEngine.nextInt());
     String disallowedNetwork = "sdn" + Integer.toString(randEngine.nextInt());
@@ -557,7 +590,7 @@ public class TestDockerContainerRuntime {
         customNetwork1);
         customNetwork1);
 
 
     //this should cause no failures.
     //this should cause no failures.
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
     runtime.launchContainer(builder.build());
     runtime.launchContainer(builder.build());
     PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
     PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
     List<String> args = op.getArguments();
     List<String> args = op.getArguments();
@@ -661,7 +694,7 @@ public class TestDockerContainerRuntime {
       IOException{
       IOException{
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(DockerLinuxContainerRuntime
     env.put(DockerLinuxContainerRuntime
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "invalid-value");
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "invalid-value");
@@ -690,7 +723,7 @@ public class TestDockerContainerRuntime {
       IOException{
       IOException{
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(DockerLinuxContainerRuntime
     env.put(DockerLinuxContainerRuntime
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
@@ -713,7 +746,7 @@ public class TestDockerContainerRuntime {
 
 
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(DockerLinuxContainerRuntime
     env.put(DockerLinuxContainerRuntime
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
@@ -743,7 +776,7 @@ public class TestDockerContainerRuntime {
 
 
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(DockerLinuxContainerRuntime
     env.put(DockerLinuxContainerRuntime
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
@@ -770,7 +803,7 @@ public class TestDockerContainerRuntime {
 
 
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(DockerLinuxContainerRuntime
     env.put(DockerLinuxContainerRuntime
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
             .ENV_DOCKER_CONTAINER_RUN_PRIVILEGED_CONTAINER, "true");
@@ -822,7 +855,7 @@ public class TestDockerContainerRuntime {
 
 
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime
         (mockExecutor, mockCGroupsHandler);
         (mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     String resourceOptionsNone = "cgroups=none";
     String resourceOptionsNone = "cgroups=none";
     DockerRunCommand command = Mockito.mock(DockerRunCommand.class);
     DockerRunCommand command = Mockito.mock(DockerRunCommand.class);
@@ -849,7 +882,7 @@ public class TestDockerContainerRuntime {
 
 
     runtime = new DockerLinuxContainerRuntime
     runtime = new DockerLinuxContainerRuntime
         (mockExecutor, null);
         (mockExecutor, null);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     runtime.addCGroupParentIfRequired(resourceOptionsNone, containerIdStr,
     runtime.addCGroupParentIfRequired(resourceOptionsNone, containerIdStr,
         command);
         command);
@@ -866,7 +899,7 @@ public class TestDockerContainerRuntime {
       IOException{
       IOException{
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(
     env.put(
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
@@ -886,7 +919,7 @@ public class TestDockerContainerRuntime {
       IOException{
       IOException{
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(
     env.put(
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
@@ -935,7 +968,7 @@ public class TestDockerContainerRuntime {
       IOException{
       IOException{
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(
     env.put(
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
@@ -955,7 +988,7 @@ public class TestDockerContainerRuntime {
       IOException{
       IOException{
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
     DockerLinuxContainerRuntime runtime = new DockerLinuxContainerRuntime(
         mockExecutor, mockCGroupsHandler);
         mockExecutor, mockCGroupsHandler);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     env.put(
     env.put(
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
         DockerLinuxContainerRuntime.ENV_DOCKER_CONTAINER_LOCAL_RESOURCE_MOUNTS,
@@ -1011,7 +1044,7 @@ public class TestDockerContainerRuntime {
         .setExecutionAttribute(USER, user)
         .setExecutionAttribute(USER, user)
         .setExecutionAttribute(PID, signalPid)
         .setExecutionAttribute(PID, signalPid)
         .setExecutionAttribute(SIGNAL, ContainerExecutor.Signal.NULL);
         .setExecutionAttribute(SIGNAL, ContainerExecutor.Signal.NULL);
-    runtime.initialize(enableMockContainerExecutor(conf));
+    runtime.initialize(enableMockContainerExecutor(conf), null);
     runtime.signalContainer(builder.build());
     runtime.signalContainer(builder.build());
 
 
     PrivilegedOperation op = capturePrivilegedOperation();
     PrivilegedOperation op = capturePrivilegedOperation();
@@ -1071,7 +1104,7 @@ public class TestDockerContainerRuntime {
         .setExecutionAttribute(USER, user)
         .setExecutionAttribute(USER, user)
         .setExecutionAttribute(PID, signalPid)
         .setExecutionAttribute(PID, signalPid)
         .setExecutionAttribute(SIGNAL, signal);
         .setExecutionAttribute(SIGNAL, signal);
-    runtime.initialize(enableMockContainerExecutor(conf));
+    runtime.initialize(enableMockContainerExecutor(conf), null);
     runtime.signalContainer(builder.build());
     runtime.signalContainer(builder.build());
 
 
     PrivilegedOperation op = capturePrivilegedOperation();
     PrivilegedOperation op = capturePrivilegedOperation();
@@ -1148,4 +1181,127 @@ public class TestDockerContainerRuntime {
       }
       }
     }
     }
   }
   }
+
+  @SuppressWarnings("unchecked")
+  private void checkVolumeCreateCommand()
+      throws PrivilegedOperationException, IOException {
+    ArgumentCaptor<PrivilegedOperation> opCaptor = ArgumentCaptor.forClass(
+        PrivilegedOperation.class);
+
+    //single invocation expected
+    //due to type erasure + mocking, this verification requires a suppress
+    // warning annotation on the entire method
+    verify(mockExecutor, times(1))
+        .executePrivilegedOperation(anyList(), opCaptor.capture(), any(
+            File.class), anyMap(), anyBoolean(), anyBoolean());
+
+    //verification completed. we need to isolate specific invocations.
+    // hence, reset mock here
+    Mockito.reset(mockExecutor);
+
+    PrivilegedOperation op = opCaptor.getValue();
+    Assert.assertEquals(PrivilegedOperation.OperationType
+        .RUN_DOCKER_CMD, op.getOperationType());
+
+    File commandFile = new File(StringUtils.join(",", op.getArguments()));
+    FileInputStream fileInputStream = new FileInputStream(commandFile);
+    String fileContent = new String(IOUtils.toByteArray(fileInputStream));
+    Assert.assertEquals("[docker-command-execution]\n"
+        + "  docker-command=volume\n" + "  sub-command=create\n"
+        + "  volume=volume1\n", fileContent);
+  }
+
+  @Test
+  public void testDockerCommandPlugin() throws Exception {
+    DockerLinuxContainerRuntime runtime =
+        new DockerLinuxContainerRuntime(mockExecutor, mockCGroupsHandler);
+
+    Context nmContext = mock(Context.class);
+    ResourcePluginManager rpm = mock(ResourcePluginManager.class);
+    Map<String, ResourcePlugin> pluginsMap = new HashMap<>();
+    ResourcePlugin plugin1 = mock(ResourcePlugin.class);
+
+    // Create the docker command plugin logic, which will set volume driver
+    DockerCommandPlugin dockerCommandPlugin = new DockerCommandPlugin() {
+      @Override
+      public void updateDockerRunCommand(DockerRunCommand dockerRunCommand,
+          Container container) throws ContainerExecutionException {
+        dockerRunCommand.setVolumeDriver("driver-1");
+        dockerRunCommand.addReadOnlyMountLocation("/source/path",
+            "/destination/path", true);
+      }
+
+      @Override
+      public DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
+          throws ContainerExecutionException {
+        return new DockerVolumeCommand("create").setVolumeName("volume1");
+      }
+
+      @Override
+      public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
+          throws ContainerExecutionException {
+        return null;
+      }
+    };
+
+    when(plugin1.getDockerCommandPluginInstance()).thenReturn(
+        dockerCommandPlugin);
+    ResourcePlugin plugin2 = mock(ResourcePlugin.class);
+    pluginsMap.put("plugin1", plugin1);
+    pluginsMap.put("plugin2", plugin2);
+
+    when(rpm.getNameToPlugins()).thenReturn(pluginsMap);
+
+    when(nmContext.getResourcePluginManager()).thenReturn(rpm);
+
+    runtime.initialize(conf, nmContext);
+
+    ContainerRuntimeContext containerRuntimeContext = builder.build();
+
+    runtime.prepareContainer(containerRuntimeContext);
+    checkVolumeCreateCommand();
+
+    runtime.launchContainer(containerRuntimeContext);
+    PrivilegedOperation op = capturePrivilegedOperationAndVerifyArgs();
+    List<String> args = op.getArguments();
+    String dockerCommandFile = args.get(11);
+
+    List<String> dockerCommands = Files.readAllLines(Paths.get
+        (dockerCommandFile), Charset.forName("UTF-8"));
+
+    int expected = 15;
+    int counter = 0;
+    Assert.assertEquals(expected, dockerCommands.size());
+    Assert.assertEquals("[docker-command-execution]",
+        dockerCommands.get(counter++));
+    Assert.assertEquals("  cap-add=SYS_CHROOT,NET_BIND_SERVICE",
+        dockerCommands.get(counter++));
+    Assert.assertEquals("  cap-drop=ALL", dockerCommands.get(counter++));
+    Assert.assertEquals("  detach=true", dockerCommands.get(counter++));
+    Assert.assertEquals("  docker-command=run", dockerCommands.get(counter++));
+    Assert.assertEquals("  hostname=ctr-id", dockerCommands.get(counter++));
+    Assert
+        .assertEquals("  image=busybox:latest", dockerCommands.get(counter++));
+    Assert.assertEquals(
+        "  launch-command=bash,/test_container_work_dir/launch_container.sh",
+        dockerCommands.get(counter++));
+    Assert.assertEquals("  name=container_id", dockerCommands.get(counter++));
+    Assert.assertEquals("  net=host", dockerCommands.get(counter++));
+    Assert.assertEquals("  ro-mounts=/source/path:/destination/path",
+        dockerCommands.get(counter++));
+    Assert.assertEquals(
+        "  rw-mounts=/test_container_local_dir:/test_container_local_dir,"
+            + "/test_filecache_dir:/test_filecache_dir,"
+            + "/test_container_work_dir:/test_container_work_dir,"
+            + "/test_container_log_dir:/test_container_log_dir,"
+            + "/test_user_local_dir:/test_user_local_dir",
+        dockerCommands.get(counter++));
+    Assert.assertEquals("  user=run_as_user", dockerCommands.get(counter++));
+
+    // Verify volume-driver is set to expected value.
+    Assert.assertEquals("  volume-driver=driver-1",
+        dockerCommands.get(counter++));
+    Assert.assertEquals("  workdir=/test_container_work_dir",
+        dockerCommands.get(counter++));
+  }
 }
 }

+ 1 - 2
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/TestJavaSandboxLinuxContainerRuntime.java

@@ -55,7 +55,6 @@ import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import java.util.regex.Pattern;
 
 
 import static org.apache.hadoop.yarn.api.ApplicationConstants.Environment.JAVA_HOME;
 import static org.apache.hadoop.yarn.api.ApplicationConstants.Environment.JAVA_HOME;
-import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.LOG;
 import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.MULTI_COMMAND_REGEX;
 import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.MULTI_COMMAND_REGEX;
 import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.CLEAN_CMD_REGEX;
 import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.CLEAN_CMD_REGEX;
 import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.CONTAINS_JAVA_CMD;
 import static org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.JavaSandboxLinuxContainerRuntime.NMContainerPolicyUtils.CONTAINS_JAVA_CMD;
@@ -134,7 +133,7 @@ public class TestJavaSandboxLinuxContainerRuntime {
 
 
     mockExecutor = mock(PrivilegedOperationExecutor.class);
     mockExecutor = mock(PrivilegedOperationExecutor.class);
     runtime = new JavaSandboxLinuxContainerRuntime(mockExecutor);
     runtime = new JavaSandboxLinuxContainerRuntime(mockExecutor);
-    runtime.initialize(conf);
+    runtime.initialize(conf, null);
 
 
     resources = new HashMap<>();
     resources = new HashMap<>();
     grantDir = new File(baseTestDirectory, "grantDir");
     grantDir = new File(baseTestDirectory, "grantDir");

+ 2 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/TestDockerCommandExecutor.java

@@ -85,7 +85,8 @@ public class TestDockerCommandExecutor {
 
 
     builder.setExecutionAttribute(CONTAINER_ID_STR, MOCK_CONTAINER_ID);
     builder.setExecutionAttribute(CONTAINER_ID_STR, MOCK_CONTAINER_ID);
     runtime.initialize(
     runtime.initialize(
-        TestDockerContainerRuntime.enableMockContainerExecutor(configuration));
+        TestDockerContainerRuntime.enableMockContainerExecutor(configuration),
+        null);
   }
   }
 
 
   @Test
   @Test

+ 45 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/runtime/docker/TestDockerVolumeCommand.java

@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestDockerVolumeCommand {
+  @Test
+  public void testDockerVolumeCommand() {
+    DockerVolumeCommand dockerVolumeCommand = new DockerVolumeCommand("create");
+    assertEquals("volume", dockerVolumeCommand.getCommandOption());
+    Assert.assertTrue(
+        dockerVolumeCommand.getDockerCommandWithArguments().get("sub-command")
+            .contains("create"));
+
+    dockerVolumeCommand.setDriverName("driver1");
+    dockerVolumeCommand.setVolumeName("volume1");
+
+    Assert.assertTrue(
+        dockerVolumeCommand.getDockerCommandWithArguments().get("driver")
+            .contains("driver1"));
+
+    Assert.assertTrue(
+        dockerVolumeCommand.getDockerCommandWithArguments().get("volume")
+            .contains("volume1"));
+  }
+}

+ 26 - 8
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java

@@ -101,23 +101,41 @@ public class TestGpuDiscoverer {
     GpuDeviceInformation info = plugin.getGpuDeviceInformation();
     GpuDeviceInformation info = plugin.getGpuDeviceInformation();
 
 
     Assert.assertTrue(info.getGpus().size() > 0);
     Assert.assertTrue(info.getGpus().size() > 0);
-    Assert.assertEquals(plugin.getMinorNumbersOfGpusUsableByYarn().size(),
+    Assert.assertEquals(plugin.getGpusUsableByYarn().size(),
         info.getGpus().size());
         info.getGpus().size());
   }
   }
 
 
   @Test
   @Test
   public void getNumberOfUsableGpusFromConfig() throws YarnException {
   public void getNumberOfUsableGpusFromConfig() throws YarnException {
     Configuration conf = new Configuration(false);
     Configuration conf = new Configuration(false);
-    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0,1,2,4");
+
+    // Illegal format
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:2,3");
     GpuDiscoverer plugin = new GpuDiscoverer();
     GpuDiscoverer plugin = new GpuDiscoverer();
+    try {
+      plugin.initialize(conf);
+      plugin.getGpusUsableByYarn();
+      Assert.fail("Illegal format, should fail.");
+    } catch (YarnException e) {
+      // Expected
+    }
+
+    // Valid format
+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:2,3:4");
+    plugin = new GpuDiscoverer();
     plugin.initialize(conf);
     plugin.initialize(conf);
 
 
-    List<Integer> minorNumbers = plugin.getMinorNumbersOfGpusUsableByYarn();
-    Assert.assertEquals(4, minorNumbers.size());
+    List<GpuDevice> usableGpuDevices = plugin.getGpusUsableByYarn();
+    Assert.assertEquals(4, usableGpuDevices.size());
+
+    Assert.assertTrue(0 == usableGpuDevices.get(0).getIndex());
+    Assert.assertTrue(1 == usableGpuDevices.get(1).getIndex());
+    Assert.assertTrue(2 == usableGpuDevices.get(2).getIndex());
+    Assert.assertTrue(3 == usableGpuDevices.get(3).getIndex());
 
 
-    Assert.assertTrue(0 == minorNumbers.get(0));
-    Assert.assertTrue(1 == minorNumbers.get(1));
-    Assert.assertTrue(2 == minorNumbers.get(2));
-    Assert.assertTrue(4 == minorNumbers.get(3));
+    Assert.assertTrue(0 == usableGpuDevices.get(0).getMinorNumber());
+    Assert.assertTrue(1 == usableGpuDevices.get(1).getMinorNumber());
+    Assert.assertTrue(2 == usableGpuDevices.get(2).getMinorNumber());
+    Assert.assertTrue(4 == usableGpuDevices.get(3).getMinorNumber());
   }
   }
 }
 }

+ 217 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV1CommandPlugin.java

@@ -0,0 +1,217 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import com.sun.net.httpserver.HttpExchange;
+import com.sun.net.httpserver.HttpHandler;
+import com.sun.net.httpserver.HttpServer;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TestNvidiaDockerV1CommandPlugin {
+  private Map<String, List<String>> copyCommandLine(
+      Map<String, List<String>> map) {
+    Map<String, List<String>> ret = new HashMap<>();
+    for (Map.Entry<String, List<String>> entry : map.entrySet()) {
+      ret.put(entry.getKey(), new ArrayList<>(entry.getValue()));
+    }
+    return ret;
+  }
+
+  private boolean commandlinesEquals(Map<String, List<String>> cli1,
+      Map<String, List<String>> cli2) {
+    if (!Sets.symmetricDifference(cli1.keySet(), cli2.keySet()).isEmpty()) {
+      return false;
+    }
+
+    for (String key : cli1.keySet()) {
+      List<String> value1 = cli1.get(key);
+      List<String> value2 = cli2.get(key);
+      if (!value1.equals(value2)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  static class MyHandler implements HttpHandler {
+    String response = "This is the response";
+
+    @Override
+    public void handle(HttpExchange t) throws IOException {
+      t.sendResponseHeaders(200, response.length());
+      OutputStream os = t.getResponseBody();
+      os.write(response.getBytes());
+      os.close();
+    }
+  }
+
+  static class MyNvidiaDockerV1CommandPlugin
+      extends NvidiaDockerV1CommandPlugin {
+    private boolean requestsGpu = false;
+
+    public MyNvidiaDockerV1CommandPlugin(Configuration conf) {
+      super(conf);
+    }
+
+    public void setRequestsGpu(boolean r) {
+      requestsGpu = r;
+    }
+
+    @Override
+    protected boolean requestsGpu(Container container) {
+      return requestsGpu;
+    }
+  }
+
+  @Test
+  public void testPlugin() throws Exception {
+    Configuration conf = new Configuration();
+
+    DockerRunCommand runCommand = new DockerRunCommand("container_1", "user",
+        "fakeimage");
+
+    Map<String, List<String>> originalCommandline = copyCommandLine(
+        runCommand.getDockerCommandWithArguments());
+
+    MyNvidiaDockerV1CommandPlugin
+        commandPlugin = new MyNvidiaDockerV1CommandPlugin(conf);
+
+    Container nmContainer = mock(Container.class);
+
+    // getResourceMapping is null, so commandline won't be updated
+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);
+    Assert.assertTrue(commandlinesEquals(originalCommandline,
+        runCommand.getDockerCommandWithArguments()));
+
+    // no GPU resource assigned, so commandline won't be updated
+    ResourceMappings resourceMappings = new ResourceMappings();
+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);
+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);
+    Assert.assertTrue(commandlinesEquals(originalCommandline,
+        runCommand.getDockerCommandWithArguments()));
+
+    // Assign GPU resource, init will be invoked
+    ResourceMappings.AssignedResources assigned =
+        new ResourceMappings.AssignedResources();
+    assigned.updateAssignedResources(
+        ImmutableList.of(new GpuDevice(0, 0), new GpuDevice(1, 1)));
+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,
+        assigned);
+
+    commandPlugin.setRequestsGpu(true);
+
+    // Since there's no HTTP server running, we will see an exception
+    boolean caughtException = false;
+    try {
+      commandPlugin.updateDockerRunCommand(runCommand, nmContainer);
+    } catch (ContainerExecutionException e) {
+      caughtException = true;
+    }
+    Assert.assertTrue(caughtException);
+
+    // Start HTTP server
+    MyHandler handler = new MyHandler();
+    HttpServer server = HttpServer.create(new InetSocketAddress(60111), 0);
+    server.createContext("/test", handler);
+    server.start();
+
+    String hostName = server.getAddress().getHostName();
+    int port = server.getAddress().getPort();
+    String httpUrl = "http://" + hostName + ":" + port + "/test";
+
+    conf.set(YarnConfiguration.NVIDIA_DOCKER_PLUGIN_V1_ENDPOINT, httpUrl);
+
+    commandPlugin = new MyNvidiaDockerV1CommandPlugin(conf);
+
+    // Start using invalid options
+    handler.response = "INVALID_RESPONSE";
+    try {
+      commandPlugin.updateDockerRunCommand(runCommand, nmContainer);
+    } catch (ContainerExecutionException e) {
+      caughtException = true;
+    }
+    Assert.assertTrue(caughtException);
+
+    // Start using invalid options
+    handler.response = "INVALID_RESPONSE";
+    try {
+      commandPlugin.updateDockerRunCommand(runCommand, nmContainer);
+    } catch (ContainerExecutionException e) {
+      caughtException = true;
+    }
+    Assert.assertTrue(caughtException);
+
+    /* Test get docker run command */
+    handler.response = "--device=/dev/nvidiactl --device=/dev/nvidia-uvm "
+        + "--device=/dev/nvidia0 --device=/dev/nvidia1 "
+        + "--volume-driver=nvidia-docker "
+        + "--volume=nvidia_driver_352.68:/usr/local/nvidia:ro";
+
+    commandPlugin.setRequestsGpu(true);
+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);
+    Map<String, List<String>> newCommandLine =
+        runCommand.getDockerCommandWithArguments();
+
+    // Command line will be updated
+    Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine));
+    // Volume driver should not be included by final commandline
+    Assert.assertFalse(newCommandLine.containsKey("volume-driver"));
+    Assert.assertTrue(newCommandLine.containsKey("devices"));
+    Assert.assertTrue(newCommandLine.containsKey("ro-mounts"));
+
+    /* Test get docker volume command */
+    commandPlugin = new MyNvidiaDockerV1CommandPlugin(conf);
+
+    // When requestsGpu == false, the returned docker volume command is null.
+    Assert.assertNull(commandPlugin.getCreateDockerVolumeCommand(nmContainer));
+
+    // set requests Gpu to true
+    commandPlugin.setRequestsGpu(true);
+
+    DockerVolumeCommand dockerVolumeCommand = commandPlugin.getCreateDockerVolumeCommand(
+        nmContainer);
+    Assert.assertEquals(
+        "volume docker-command=volume " + "driver=nvidia-docker "
+            + "sub-command=create " + "volume=nvidia_driver_352.68",
+        dockerVolumeCommand.toString());
+  }
+}

+ 6 - 2
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/recovery/NMMemoryStateStoreService.java

@@ -42,6 +42,7 @@ import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDelet
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.impl.pb.MasterKeyPBImpl;
 import org.apache.hadoop.yarn.server.api.records.impl.pb.MasterKeyPBImpl;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 
 
 public class NMMemoryStateStoreService extends NMStateStoreService {
 public class NMMemoryStateStoreService extends NMStateStoreService {
@@ -503,14 +504,17 @@ public class NMMemoryStateStoreService extends NMStateStoreService {
   }
   }
 
 
   @Override
   @Override
-  public void storeAssignedResources(ContainerId containerId,
+  public void storeAssignedResources(Container container,
       String resourceType, List<Serializable> assignedResources)
       String resourceType, List<Serializable> assignedResources)
       throws IOException {
       throws IOException {
     ResourceMappings.AssignedResources ar =
     ResourceMappings.AssignedResources ar =
         new ResourceMappings.AssignedResources();
         new ResourceMappings.AssignedResources();
     ar.updateAssignedResources(assignedResources);
     ar.updateAssignedResources(assignedResources);
-    containerStates.get(containerId).getResourceMappings()
+    containerStates.get(container.getContainerId()).getResourceMappings()
         .addAssignedResources(resourceType, ar);
         .addAssignedResources(resourceType, ar);
+
+    // update container resource mapping.
+    updateContainerResourceMapping(container, resourceType, assignedResources);
   }
   }
 
 
   private static class TrackerState {
   private static class TrackerState {

+ 18 - 4
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/recovery/TestNMLeveldbStateStoreService.java

@@ -29,6 +29,7 @@ import static org.mockito.Mockito.isNull;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 
 import java.io.File;
 import java.io.File;
 import java.io.IOException;
 import java.io.IOException;
@@ -69,6 +70,8 @@ import org.apache.hadoop.yarn.proto.YarnServerNodemanagerRecoveryProtos.LogDelet
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.security.ContainerTokenIdentifier;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.api.records.MasterKey;
 import org.apache.hadoop.yarn.server.nodemanager.amrmproxy.AMRMProxyTokenSecretManager;
 import org.apache.hadoop.yarn.server.nodemanager.amrmproxy.AMRMProxyTokenSecretManager;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService.LocalResourceTrackerState;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService.LocalResourceTrackerState;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService.RecoveredAMRMProxyState;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService.RecoveredAMRMProxyState;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService.RecoveredApplicationsState;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService.RecoveredApplicationsState;
@@ -1124,16 +1127,21 @@ public class TestNMLeveldbStateStoreService {
     ContainerId containerId = ContainerId.newContainerId(appAttemptId, 5);
     ContainerId containerId = ContainerId.newContainerId(appAttemptId, 5);
     storeMockContainer(containerId);
     storeMockContainer(containerId);
 
 
+    Container container = mock(Container.class);
+    when(container.getContainerId()).thenReturn(containerId);
+    ResourceMappings resourceMappings = new ResourceMappings();
+    when(container.getResourceMappings()).thenReturn(resourceMappings);
+
     // Store ResourceMapping
     // Store ResourceMapping
-    stateStore.storeAssignedResources(containerId, "gpu",
+    stateStore.storeAssignedResources(container, "gpu",
         Arrays.asList("1", "2", "3"));
         Arrays.asList("1", "2", "3"));
     // This will overwrite above
     // This will overwrite above
     List<Serializable> gpuRes1 = Arrays.asList("1", "2", "4");
     List<Serializable> gpuRes1 = Arrays.asList("1", "2", "4");
-    stateStore.storeAssignedResources(containerId, "gpu", gpuRes1);
+    stateStore.storeAssignedResources(container, "gpu", gpuRes1);
     List<Serializable> fpgaRes = Arrays.asList("3", "4", "5", "6");
     List<Serializable> fpgaRes = Arrays.asList("3", "4", "5", "6");
-    stateStore.storeAssignedResources(containerId, "fpga", fpgaRes);
+    stateStore.storeAssignedResources(container, "fpga", fpgaRes);
     List<Serializable> numaRes = Arrays.asList("numa1");
     List<Serializable> numaRes = Arrays.asList("numa1");
-    stateStore.storeAssignedResources(containerId, "numa", numaRes);
+    stateStore.storeAssignedResources(container, "numa", numaRes);
 
 
     // add a invalid key
     // add a invalid key
     restartStateStore();
     restartStateStore();
@@ -1143,12 +1151,18 @@ public class TestNMLeveldbStateStoreService {
     List<Serializable> res = rcs.getResourceMappings()
     List<Serializable> res = rcs.getResourceMappings()
         .getAssignedResources("gpu");
         .getAssignedResources("gpu");
     Assert.assertTrue(res.equals(gpuRes1));
     Assert.assertTrue(res.equals(gpuRes1));
+    Assert.assertTrue(
+        resourceMappings.getAssignedResources("gpu").equals(gpuRes1));
 
 
     res = rcs.getResourceMappings().getAssignedResources("fpga");
     res = rcs.getResourceMappings().getAssignedResources("fpga");
     Assert.assertTrue(res.equals(fpgaRes));
     Assert.assertTrue(res.equals(fpgaRes));
+    Assert.assertTrue(
+        resourceMappings.getAssignedResources("fpga").equals(fpgaRes));
 
 
     res = rcs.getResourceMappings().getAssignedResources("numa");
     res = rcs.getResourceMappings().getAssignedResources("numa");
     Assert.assertTrue(res.equals(numaRes));
     Assert.assertTrue(res.equals(numaRes));
+    Assert.assertTrue(
+        resourceMappings.getAssignedResources("numa").equals(numaRes));
   }
   }
 
 
   private StartContainerRequest storeMockContainer(ContainerId containerId)
   private StartContainerRequest storeMockContainer(ContainerId containerId)