ソースを参照

YARN-9060. [YARN-8851] Phase 1 - Support device isolation and use the Nvidia GPU plugin as an example. Contributed by Zhankun Tang.

Sunil G 6 年 前
コミット
db4d1a1e2f
20 ファイル変更1970 行追加89 行削除
  1. 6 1
      hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg
  2. 2 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/CMakeLists.txt
  3. 2 2
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/Device.java
  4. 1 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/privileged/PrivilegedOperation.java
  5. 240 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2.java
  6. 19 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/package-info.java
  7. 44 16
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java
  8. 22 3
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DevicePluginAdapter.java
  9. 233 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceDockerRuntimePluginImpl.java
  10. 202 12
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceHandlerImpl.java
  11. 46 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/ShellWrapper.java
  12. 6 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/main.c
  13. 1 1
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/cgroups/cgroups-operations.c
  14. 281 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.c
  15. 45 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.h
  16. 3 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/util.c
  17. 298 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/modules/devices/test-devices-module.cc
  18. 28 25
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDeviceMappingManager.java
  19. 383 29
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java
  20. 108 0
      hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java

+ 6 - 1
hadoop-yarn-project/hadoop-yarn/conf/container-executor.cfg

@@ -22,4 +22,9 @@ feature.tc.enabled=false
 #[fpga]
 #  module.enabled=## Enable/Disable the FPGA resource handler module. set to "true" to enable, disabled by default
 #  fpga.major-device-number=## Major device number of FPGA, by default is 246. Strongly recommend setting this
-#  fpga.allowed-device-minor-numbers=## Comma separated allowed minor device numbers, empty means all FPGA devices managed by YARN.
+#  fpga.allowed-device-minor-numbers=## Comma separated allowed minor device numbers, empty means all FPGA devices managed by YARN.
+
+# The configs below deal with settings for resource handled by pluggable device plugin framework
+#[devices]
+#  module.enabled=## Enable/Disable the device resource handler module for isolation. Disabled by default.
+#  devices.denied-numbers=## Blacklisted devices that containers are not permitted to use. The format is comma separated "majorNumber:minorNumber". For instance, "195:1,195:2". Leaving it empty means all devices reported by the device plugin are allowed.

+ 2 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/CMakeLists.txt

@@ -135,6 +135,7 @@ add_library(container
     main/native/container-executor/impl/modules/common/module-configs.c
     main/native/container-executor/impl/modules/gpu/gpu-module.c
     main/native/container-executor/impl/modules/fpga/fpga-module.c
+    main/native/container-executor/impl/modules/devices/devices-module.c
     main/native/container-executor/impl/utils/docker-util.c
 )
 
@@ -169,6 +170,7 @@ add_executable(cetest
         main/native/container-executor/test/modules/cgroups/test-cgroups-module.cc
         main/native/container-executor/test/modules/gpu/test-gpu-module.cc
         main/native/container-executor/test/modules/fpga/test-fpga-module.cc
+        main/native/container-executor/test/modules/devices/test-devices-module.cc
         main/native/container-executor/test/test_util.cc
         main/native/container-executor/test/utils/test_docker_util.cc)
 target_link_libraries(cetest gtest container)

+ 2 - 2
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/api/deviceplugin/Device.java

@@ -181,8 +181,8 @@ public final class Device implements Serializable, Comparable {
     // default -1 representing the value is not set
     private int id = -1;
     private String devPath = "";
-    private int majorNumber;
-    private int minorNumber;
+    private int majorNumber = -1;
+    private int minorNumber = -1;
     private String busID = "";
     private boolean isHealthy;
     private String status = "";

+ 1 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/privileged/PrivilegedOperation.java

@@ -54,6 +54,7 @@ public class PrivilegedOperation {
     RUN_DOCKER_CMD("--run-docker"),
     GPU("--module-gpu"),
     FPGA("--module-fpga"),
+    DEVICE("--module-devices"),
     LIST_AS_USER(""), // no CLI switch supported yet.
     ADD_NUMA_PARAMS(""), // no CLI switch supported yet.
     REMOVE_DOCKER_CONTAINER("--remove-docker-container"),

+ 240 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2.java

@@ -0,0 +1,240 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableSet;
+import org.apache.hadoop.util.Shell;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+/**
+ * Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
+ * non-Docker container.
+ * */
+public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin {
+  public static final Logger LOG = LoggerFactory.getLogger(
+      NvidiaGPUPluginForRuntimeV2.class);
+
+  public static final String NV_RESOURCE_NAME = "nvidia.com/gpu";
+
+  private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();
+
+  private Map<String, String> environment = new HashMap<>();
+
+  // If this environment is set, use it directly
+  private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH";
+
+  private static final String DEFAULT_BINARY_NAME = "nvidia-smi";
+
+  private static final String DEV_NAME_PREFIX = "nvidia";
+
+  private String pathOfGpuBinary = null;
+
+  // command should not run more than 10 sec.
+  private static final int MAX_EXEC_TIMEOUT_MS = 10 * 1000;
+
+  // When executable path not set, try to search default dirs
+  // By default search /usr/bin, /bin, and /usr/local/nvidia/bin (when
+  // launched by nvidia-docker).
+  private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of(
+      "/usr/bin", "/bin", "/usr/local/nvidia/bin");
+
+  @Override
+  public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
+    return DeviceRegisterRequest.Builder.newInstance()
+        .setResourceName(NV_RESOURCE_NAME).build();
+  }
+
+  /**
+   * Discovers the GPUs on this node by parsing the output of the GPU binary.
+   * Each output line must hold exactly two comma-separated tokens; the first
+   * is used as the device minor number and the second as the bus ID. A device
+   * whose major number cannot be resolved is skipped.
+   *
+   * @return the discovered devices, each marked healthy
+   * @throws Exception if locating the binary fails or an output line does not
+   *     have exactly two tokens; an IOException raised while running the
+   *     command is rethrown wrapped in a YarnException
+   */
+  @Override
+  public Set<Device> getDevices() throws Exception {
+    // make sure the path of the GPU binary is resolved before running it
+    shellExecutor.searchBinary();
+    TreeSet<Device> r = new TreeSet<>();
+    String output;
+    try {
+      output = shellExecutor.getDeviceInfo();
+      String[] lines = output.trim().split("\n");
+      // ids are handed out sequentially, only to devices actually added
+      int id = 0;
+      for (String oneLine : lines) {
+        String[] tokensEachLine = oneLine.split(",");
+        if (tokensEachLine.length != 2) {
+          throw new Exception("Cannot parse the output to get device info. "
+              + "Unexpected format in it:" + oneLine);
+        }
+        String minorNumber = tokensEachLine[0].trim();
+        String busId = tokensEachLine[1].trim();
+        String majorNumber = getMajorNumber(DEV_NAME_PREFIX
+            + minorNumber);
+        // null means the major number could not be determined; skip device
+        if (majorNumber != null) {
+          r.add(Device.Builder.newInstance()
+              .setId(id)
+              .setMajorNumber(Integer.parseInt(majorNumber))
+              .setMinorNumber(Integer.parseInt(minorNumber))
+              .setBusID(busId)
+              .setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber)
+              .setHealthy(true)
+              .build());
+          id++;
+        }
+      }
+      return r;
+    } catch (IOException e) {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Failed to get output from " + pathOfGpuBinary);
+      }
+      throw new YarnException(e);
+    }
+  }
+
+  /**
+   * Builds the runtime spec for a container that was allocated GPUs.
+   * Only the Docker runtime gets a spec: the minor numbers of the allocated
+   * devices are passed as a comma separated list in the
+   * NVIDIA_VISIBLE_DEVICES environment variable and the container runtime is
+   * set to "nvidia".
+   *
+   * @param allocatedDevices devices assigned to the container
+   * @param yarnRuntime runtime type the container is launched with
+   * @return the runtime spec, or null for a non-Docker runtime or when no
+   *     device was allocated
+   */
+  @Override
+  public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
+      YarnRuntimeType yarnRuntime) throws Exception {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Generating runtime spec for allocated devices: "
+          + allocatedDevices + ", " + yarnRuntime.getName());
+    }
+    if (yarnRuntime != YarnRuntimeType.RUNTIME_DOCKER) {
+      return null;
+    }
+    // Guard the empty case: there is no device list to expose, and trimming
+    // a trailing separator from an empty string would throw.
+    if (allocatedDevices == null || allocatedDevices.isEmpty()) {
+      LOG.warn("No allocated device given, skip generating runtime spec");
+      return null;
+    }
+    StringBuilder gpuMinorNumbersSB = new StringBuilder();
+    for (Device device : allocatedDevices) {
+      // separator goes between entries, so no trailing comma to strip
+      if (gpuMinorNumbersSB.length() > 0) {
+        gpuMinorNumbersSB.append(',');
+      }
+      gpuMinorNumbersSB.append(device.getMinorNumber());
+    }
+    String minorNumbers = gpuMinorNumbersSB.toString();
+    LOG.info("Nvidia Docker v2 assigned GPU: " + minorNumbers);
+    return DeviceRuntimeSpec.Builder.newInstance()
+        .addEnv("NVIDIA_VISIBLE_DEVICES", minorNumbers)
+        .setContainerRuntime("nvidia")
+        .build();
+  }
+
+  @Override
+  public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {
+    // do nothing
+  }
+
+  /**
+   * Gets the major number of a device from its name.
+   *
+   * @param devName device file name under /dev, e.g. "nvidia0"
+   * @return the major number as a decimal string, or null if it cannot be
+   *     read or parsed
+   */
+  private String getMajorNumber(String devName) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Get major numbers from /dev/" + devName);
+    }
+    String statOutput = null;
+    try {
+      // output "major:minor" in hex
+      statOutput = shellExecutor.getMajorMinorInfo(devName);
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("stat output:" + statOutput);
+      }
+      String[] strs = statOutput.trim().split(":");
+      // stat reports the number in hex; callers expect decimal
+      return Integer.toString(Integer.parseInt(strs[0], 16));
+    } catch (IOException e) {
+      // keep the cause so a failing stat call can be diagnosed
+      LOG.warn("Failed to get major number from reading /dev/" + devName, e);
+    } catch (NumberFormatException e) {
+      LOG.error("Failed to parse device major number from stat output: "
+          + statOutput, e);
+    }
+    return null;
+  }
+
+  /**
+   * A shell wrapper class that makes testing easier.
+   * */
+  public class NvidiaCommandExecutor {
+
+    public String getDeviceInfo() throws IOException {
+      return Shell.execCommand(environment,
+          new String[]{pathOfGpuBinary, "--query-gpu=index,pci.bus_id",
+              "--format=csv,noheader"}, MAX_EXEC_TIMEOUT_MS);
+    }
+
+    public String getMajorMinorInfo(String devName) throws IOException {
+      // output "major:minor" in hex
+      Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(
+          new String[]{"stat", "-c", "%t:%T", "/dev/" + devName});
+      shexec.execute();
+      return shexec.getOutput();
+    }
+
+    public void searchBinary() throws Exception {
+      if (pathOfGpuBinary != null) {
+        LOG.info("Skip searching, the nvidia gpu binary is already set: "
+            + pathOfGpuBinary);
+        return;
+      }
+      // search env for the binary
+      String envBinaryPath = System.getenv(ENV_BINARY_PATH);
+      if (null != envBinaryPath) {
+        if (new File(envBinaryPath).exists()) {
+          pathOfGpuBinary = envBinaryPath;
+          LOG.info("Use nvidia gpu binary: " + pathOfGpuBinary);
+          return;
+        }
+      }
+      LOG.info("Search binary..");
+      // search if binary exists in default folders
+      File binaryFile;
+      boolean found = false;
+      for (String dir : DEFAULT_BINARY_SEARCH_DIRS) {
+        binaryFile = new File(dir, DEFAULT_BINARY_NAME);
+        if (binaryFile.exists()) {
+          found = true;
+          pathOfGpuBinary = binaryFile.getAbsolutePath();
+          LOG.info("Found binary:" + pathOfGpuBinary);
+          break;
+        }
+      }
+      if (!found) {
+        LOG.error("No binary found from env variable: "
+            + ENV_BINARY_PATH + " or path "
+            + DEFAULT_BINARY_SEARCH_DIRS.toString());
+        throw new Exception("No binary found for "
+            + NvidiaGPUPluginForRuntimeV2.class);
+      }
+    }
+  }
+
+  @VisibleForTesting
+  public void setPathOfGpuBinary(String pathOfGpuBinary) {
+    this.pathOfGpuBinary = pathOfGpuBinary;
+  }
+
+  @VisibleForTesting
+  public void setShellExecutor(
+      NvidiaCommandExecutor shellExecutor) {
+    this.shellExecutor = shellExecutor;
+  }
+}

+ 19 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/package-info.java

@@ -0,0 +1,19 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;

+ 44 - 16
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.java

@@ -95,6 +95,20 @@ public class DeviceMappingManager {
     return devicePluginSchedulers;
   }
 
+  /**
+   * Collects the devices of one resource type that are currently assigned
+   * to the given container.
+   *
+   * @param resourceName resource type to look up
+   * @param cId container whose devices are wanted
+   * @return the assigned devices; empty if none are assigned or the resource
+   *     type has no usage record
+   */
+  @VisibleForTesting
+  public Set<Device> getAllocatedDevices(String resourceName,
+      ContainerId cId) {
+    Set<Device> assigned = new TreeSet<>();
+    Map<Device, ContainerId> assignedMap =
+        this.getAllUsedDevices().get(resourceName);
+    if (assignedMap == null) {
+      // no usage record for this resource type: nothing can be assigned
+      return assigned;
+    }
+    for (Map.Entry<Device, ContainerId> entry : assignedMap.entrySet()) {
+      if (entry.getValue().equals(cId)) {
+        assigned.add(entry.getKey());
+      }
+    }
+    return assigned;
+  }
+
   public synchronized void addDeviceSet(String resourceName,
       Set<Device> deviceSet) {
     LOG.info("Adding new resource: " + "type:"
@@ -148,8 +162,10 @@ public class DeviceMappingManager {
     ContainerId containerId = container.getContainerId();
     int requestedDeviceCount = getRequestedDeviceCount(resourceName,
         requestedResource);
-    LOG.debug("Try allocating " + requestedDeviceCount
-        + " " + resourceName);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Try allocating " + requestedDeviceCount
+          + " " + resourceName);
+    }
     // Assign devices to container if requested some.
     if (requestedDeviceCount > 0) {
       if (requestedDeviceCount > getAvailableDevices(resourceName)) {
@@ -245,18 +261,24 @@ public class DeviceMappingManager {
       ContainerId containerId) {
     Iterator<Map.Entry<Device, ContainerId>> iter =
         allUsedDevices.get(resourceName).entrySet().iterator();
+    Map.Entry<Device, ContainerId> entry;
     while (iter.hasNext()) {
-      if (iter.next().getValue().equals(containerId)) {
+      entry = iter.next();
+      if (entry.getValue().equals(containerId)) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Recycle devices: " + entry.getKey()
+              + ", type: " + resourceName + " from " + containerId);
+        }
         iter.remove();
       }
     }
   }
 
-  public static int getRequestedDeviceCount(String resourceName,
+  public static int getRequestedDeviceCount(String resName,
       Resource requestedResource) {
     try {
       return Long.valueOf(requestedResource.getResourceValue(
-          resourceName)).intValue();
+          resName)).intValue();
     } catch (ResourceNotFoundException e) {
       return 0;
     }
@@ -270,10 +292,7 @@ public class DeviceMappingManager {
   private long getReleasingDevices(String resourceName) {
     long releasingDevices = 0;
     Map<Device, ContainerId> used = allUsedDevices.get(resourceName);
-    Iterator<Map.Entry<Device, ContainerId>> iter = used.entrySet()
-        .iterator();
-    while (iter.hasNext()) {
-      ContainerId containerId = iter.next().getValue();
+    for (ContainerId containerId : ImmutableSet.copyOf(used.values())) {
       Container container = nmContext.getContainers().get(containerId);
       if (container != null) {
         if (container.isContainerInFinalStates()) {
@@ -295,16 +314,20 @@ public class DeviceMappingManager {
       DevicePluginScheduler dps) throws ResourceHandlerException {
 
     if (null == dps) {
-      LOG.debug("Customized device plugin scheduler is preferred "
-          + "but not implemented, use default logic");
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Customized device plugin scheduler is preferred "
+            + "but not implemented, use default logic");
+      }
       defaultScheduleAction(allowed, used,
           assigned, containerId, count);
     } else {
-      LOG.debug("Customized device plugin implemented,"
-          + "use customized logic");
-      // Use customized device scheduler
-      LOG.debug("Try to schedule " + count
-          + "(" + resourceName + ") using " + dps.getClass());
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Customized device plugin implemented,"
+            + "use customized logic");
+        // Use customized device scheduler
+        LOG.debug("Try to schedule " + count
+            + "(" + resourceName + ") using " + dps.getClass());
+      }
       // Pass in unmodifiable set
       Set<Device> dpsAllocated = dps.allocateDevices(
           Sets.difference(allowed, used.keySet()),
@@ -345,6 +368,7 @@ public class DeviceMappingManager {
     private String resourceName;
 
     private Set<Device> allowed = Collections.emptySet();
+
     private Set<Device> denied = Collections.emptySet();
 
     DeviceAllocation(String resName, Set<Device> a,
@@ -362,6 +386,10 @@ public class DeviceMappingManager {
       return allowed;
     }
 
+    public Set<Device> getDenied() {
+      return denied;
+    }
+
     @Override
     public String toString() {
       return "ResourceType: " + resourceName

+ 22 - 3
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DevicePluginAdapter.java

@@ -18,6 +18,7 @@
 
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.yarn.api.records.ContainerId;
@@ -49,10 +50,20 @@ public class DevicePluginAdapter implements ResourcePlugin {
   private final static Log LOG = LogFactory.getLog(DevicePluginAdapter.class);
 
   private final String resourceName;
+
   private final DevicePlugin devicePlugin;
   private DeviceMappingManager deviceMappingManager;
+
   private DeviceResourceHandlerImpl deviceResourceHandler;
   private DeviceResourceUpdaterImpl deviceResourceUpdater;
+  private DeviceResourceDockerRuntimePluginImpl deviceDockerCommandPlugin;
+
+
+  @VisibleForTesting
+  public void setDeviceResourceHandler(
+      DeviceResourceHandlerImpl deviceResourceHandler) {
+    this.deviceResourceHandler = deviceResourceHandler;
+  }
 
   public DevicePluginAdapter(String name, DevicePlugin dp,
       DeviceMappingManager dmm) {
@@ -65,8 +76,16 @@ public class DevicePluginAdapter implements ResourcePlugin {
     return deviceMappingManager;
   }
 
+
+  public DevicePlugin getDevicePlugin() {
+    return devicePlugin;
+  }
+
   @Override
   public void initialize(Context context) throws YarnException {
+    deviceDockerCommandPlugin = new DeviceResourceDockerRuntimePluginImpl(
+        resourceName,
+        devicePlugin, this);
     deviceResourceUpdater = new DeviceResourceUpdaterImpl(
         resourceName, devicePlugin);
     LOG.info(resourceName + " plugin adapter initialized");
@@ -78,8 +97,8 @@ public class DevicePluginAdapter implements ResourcePlugin {
       CGroupsHandler cGroupsHandler,
       PrivilegedOperationExecutor privilegedOperationExecutor) {
     this.deviceResourceHandler = new DeviceResourceHandlerImpl(resourceName,
-        devicePlugin, this, deviceMappingManager,
-        cGroupsHandler, privilegedOperationExecutor);
+        this, deviceMappingManager,
+        cGroupsHandler, privilegedOperationExecutor, nmContext);
     return deviceResourceHandler;
   }
 
@@ -95,7 +114,7 @@ public class DevicePluginAdapter implements ResourcePlugin {
 
   @Override
   public DockerCommandPlugin getDockerCommandPluginInstance() {
-    return null;
+    return deviceDockerCommandPlugin;
   }
 
   @Override

+ 233 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceDockerRuntimePluginImpl.java

@@ -0,0 +1,233 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
+import org.apache.hadoop.yarn.util.LRUCacheHashMap;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Bridge DevicePlugin and the hooks related to launching Docker containers.
+ * When launching Docker container, DockerLinuxContainerRuntime will invoke
+ * this class's methods which get needed info back from DevicePlugin.
+ * */
+public class DeviceResourceDockerRuntimePluginImpl
+    implements DockerCommandPlugin {
+
+  final static Log LOG = LogFactory.getLog(
+      DeviceResourceDockerRuntimePluginImpl.class);
+
+  private String resourceName;
+  private DevicePlugin devicePlugin;
+  private DevicePluginAdapter devicePluginAdapter;
+
+  private int maxCacheSize = 100;
+  // LRU caches bound memory use in case getCleanupDockerVolumesCommand is
+  // never invoked for a container.
+  private Map<ContainerId, Set<Device>> cachedAllocation =
+      Collections.synchronizedMap(new LRUCacheHashMap<>(maxCacheSize, true));
+
+  private Map<ContainerId, DeviceRuntimeSpec> cachedSpec =
+      Collections.synchronizedMap(new LRUCacheHashMap<>(maxCacheSize, true));
+
+  public DeviceResourceDockerRuntimePluginImpl(String resourceName,
+      DevicePlugin devicePlugin, DevicePluginAdapter devicePluginAdapter) {
+    this.resourceName = resourceName;
+    this.devicePlugin = devicePlugin;
+    this.devicePluginAdapter = devicePluginAdapter;
+  }
+
+  @Override
+  public void updateDockerRunCommand(DockerRunCommand dockerRunCommand,
+      Container container) throws ContainerExecutionException {
+    String containerId = container.getContainerId().toString();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Try to update docker run command for: " + containerId);
+    }
+    if(!requestedDevice(resourceName, container)) {
+      return;
+    }
+    DeviceRuntimeSpec deviceRuntimeSpec = getRuntimeSpec(container);
+    if (deviceRuntimeSpec == null) {
+      LOG.warn("The device plugin: "
+          + devicePlugin.getClass().getCanonicalName()
+          + " returns null device runtime spec value for container: "
+          + containerId);
+      return;
+    }
+    // handle runtime
+    dockerRunCommand.addRuntime(deviceRuntimeSpec.getContainerRuntime());
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle docker container runtime type: "
+          + deviceRuntimeSpec.getContainerRuntime() + " for container: "
+          + containerId);
+    }
+    // handle device mounts
+    Set<MountDeviceSpec> deviceMounts = deviceRuntimeSpec.getDeviceMounts();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle device mounts: " + deviceMounts + " for container: "
+          + containerId);
+    }
+    for (MountDeviceSpec mountDeviceSpec : deviceMounts) {
+      dockerRunCommand.addDevice(
+          mountDeviceSpec.getDevicePathInHost(),
+          mountDeviceSpec.getDevicePathInContainer());
+    }
+    // handle volume mounts
+    Set<MountVolumeSpec> mountVolumeSpecs = deviceRuntimeSpec.getVolumeMounts();
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle volume mounts: " + mountVolumeSpecs + " for container: "
+          + containerId);
+    }
+    for (MountVolumeSpec mountVolumeSpec : mountVolumeSpecs) {
+      if (mountVolumeSpec.getReadOnly()) {
+        dockerRunCommand.addReadOnlyMountLocation(
+            mountVolumeSpec.getHostPath(),
+            mountVolumeSpec.getMountPath());
+      } else {
+        dockerRunCommand.addReadWriteMountLocation(
+            mountVolumeSpec.getHostPath(),
+            mountVolumeSpec.getMountPath());
+      }
+    }
+    // handle envs
+    dockerRunCommand.addEnv(deviceRuntimeSpec.getEnvs());
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Handle envs: " + deviceRuntimeSpec.getEnvs()
+          + " for container: " + containerId);
+    }
+  }
+
+  @Override
+  public DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
+      throws ContainerExecutionException {
+    if(!requestedDevice(resourceName, container)) {
+      return null;
+    }
+    DeviceRuntimeSpec deviceRuntimeSpec = getRuntimeSpec(container);
+    if (deviceRuntimeSpec == null) {
+      return null;
+    }
+    Set<VolumeSpec> volumeClaims = deviceRuntimeSpec.getVolumeSpecs();
+    for (VolumeSpec volumeSec: volumeClaims) {
+      if (volumeSec.getVolumeOperation().equals(VolumeSpec.CREATE)) {
+        DockerVolumeCommand command = new DockerVolumeCommand(
+            DockerVolumeCommand.VOLUME_CREATE_SUB_COMMAND);
+        command.setDriverName(volumeSec.getVolumeDriver());
+        command.setVolumeName(volumeSec.getVolumeName());
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Get volume create request from plugin:" + volumeClaims
+              + " for container: " + container.getContainerId().toString());
+        }
+        return command;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Notifies the device plugin that the container's devices are released and
+   * drops the cached allocation and runtime spec of the container.
+   *
+   * @param container container being cleaned up
+   * @return always null; this plugin issues no volume cleanup command
+   */
+  @Override
+  public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
+      throws ContainerExecutionException {
+    if (!requestedDevice(resourceName, container)) {
+      return null;
+    }
+    ContainerId containerId = container.getContainerId();
+    Set<Device> allocated = getAllocatedDevices(container);
+    try {
+      devicePlugin.onDevicesReleased(allocated);
+    } catch (Exception e) {
+      // a failing plugin callback must not block the cache cleanup below
+      LOG.warn("Exception thrown in onDevicesReleased of "
+          + devicePlugin.getClass() + " for container: "
+          + containerId, e);
+    }
+    // remove cache
+    cachedAllocation.remove(containerId);
+    cachedSpec.remove(containerId);
+    return null;
+  }
+
+  protected boolean requestedDevice(String resName, Container container) {
+    return DeviceMappingManager.
+        getRequestedDeviceCount(resName, container.getResource()) > 0;
+  }
+
+  private Set<Device> getAllocatedDevices(Container container) {
+    // get allocated devices
+    Set<Device> allocated;
+    ContainerId containerId = container.getContainerId();
+    allocated = cachedAllocation.get(containerId);
+    if (allocated != null) {
+      return allocated;
+    }
+    allocated = devicePluginAdapter
+        .getDeviceMappingManager()
+        .getAllocatedDevices(resourceName, containerId);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Get allocation from deviceMappingManager: "
+          + allocated + ", " + resourceName + " for container: " + containerId);
+    }
+    cachedAllocation.put(containerId, allocated);
+    return allocated;
+  }
+
+  public synchronized DeviceRuntimeSpec getRuntimeSpec(Container container) {
+    ContainerId containerId = container.getContainerId();
+    DeviceRuntimeSpec deviceRuntimeSpec = cachedSpec.get(containerId);
+    if (deviceRuntimeSpec == null) {
+      Set<Device> allocated = getAllocatedDevices(container);
+      if (allocated == null || allocated.size() == 0) {
+        LOG.error("Cannot get allocation for container:" + containerId);
+        return null;
+      }
+      try {
+        deviceRuntimeSpec = devicePlugin.onDevicesAllocated(allocated,
+            YarnRuntimeType.RUNTIME_DOCKER);
+      } catch (Exception e) {
+        LOG.error("Exception thrown in onDeviceAllocated of "
+            + devicePlugin.getClass() + " for container: " + containerId, e);
+      }
+      if (deviceRuntimeSpec == null) {
+        LOG.error("Null DeviceRuntimeSpec value got from "
+            + devicePlugin.getClass() + " for container: "
+            + containerId + ", please check plugin logic");
+        return null;
+      }
+      cachedSpec.put(containerId, deviceRuntimeSpec);
+    }
+    return deviceRuntimeSpec;
+  }
+
+}

+ 202 - 12
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceResourceHandlerImpl.java

@@ -18,20 +18,29 @@
 
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.DockerLinuxContainerRuntime;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
 
@@ -52,19 +61,45 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
   private final CGroupsHandler cGroupsHandler;
   private final PrivilegedOperationExecutor privilegedOperationExecutor;
   private final DevicePluginAdapter devicePluginAdapter;
+  private final Context nmContext;
+  private ShellWrapper shellWrapper;
 
-  public DeviceResourceHandlerImpl(String reseName,
-      DevicePlugin devPlugin,
+  // This will be used by container-executor to add necessary clis
+  public static final String EXCLUDED_DEVICES_CLI_OPTION = "--excluded_devices";
+  public static final String ALLOWED_DEVICES_CLI_OPTION = "--allowed_devices";
+  public static final String CONTAINER_ID_CLI_OPTION = "--container_id";
+
+  /**
+   * Create a handler that uses a real {@link ShellWrapper}.
+   *
+   * @param resName name of the resource this handler manages
+   * @param devPluginAdapter adapter that provides the device plugin
+   * @param devMappingManager manager used to assign and clean up devices
+   * @param cgHandler handler used for the devices cgroup operations
+   * @param operation executor for container-executor operations
+   * @param ctx the NodeManager context
+   */
+  public DeviceResourceHandlerImpl(String resName,
+      DevicePluginAdapter devPluginAdapter,
+      DeviceMappingManager devMappingManager,
+      CGroupsHandler cgHandler,
+      PrivilegedOperationExecutor operation,
+      Context ctx) {
+    // Same wiring as the constructor used by tests, with a default shell
+    this(resName, devPluginAdapter, devMappingManager, cgHandler, operation,
+        ctx, new ShellWrapper());
+  }
+
+  @VisibleForTesting
+  public DeviceResourceHandlerImpl(String resName,
       DevicePluginAdapter devPluginAdapter,
       DeviceMappingManager devMappingManager,
       CGroupsHandler cgHandler,
-      PrivilegedOperationExecutor operation) {
+      PrivilegedOperationExecutor operation,
+      Context ctx, ShellWrapper shell) {
     this.devicePluginAdapter = devPluginAdapter;
-    this.resourceName = reseName;
-    this.devicePlugin = devPlugin;
+    this.resourceName = resName;
+    this.devicePlugin = devPluginAdapter.getDevicePlugin();
     this.cGroupsHandler = cgHandler;
     this.privilegedOperationExecutor = operation;
     this.deviceMappingManager = devMappingManager;
+    this.nmContext = ctx;
+    this.shellWrapper = shell;
   }
 
   @Override
@@ -98,11 +133,13 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
     String containerIdStr = container.getContainerId().toString();
     DeviceMappingManager.DeviceAllocation allocation =
         deviceMappingManager.assignDevices(resourceName, container);
-    LOG.debug("Allocated to "
-        + containerIdStr + ": " + allocation);
-
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Allocated to "
+          + containerIdStr + ": " + allocation);
+    }
+    DeviceRuntimeSpec spec;
     try {
-      devicePlugin.onDevicesAllocated(
+      spec = devicePlugin.onDevicesAllocated(
           allocation.getAllowed(), YarnRuntimeType.RUNTIME_DEFAULT);
     } catch (Exception e) {
       throw new ResourceHandlerException("Exception thrown from"
@@ -110,13 +147,95 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
     }
 
     // cgroups operation based on allocation
-    /**
-     * TODO: implement a general container-executor device module
-     * */
+    if (spec != null) {
+      LOG.warn("Runtime spec in non-Docker container is not supported yet!");
+    }
+    // Create device cgroups for the container
+    cGroupsHandler.createCGroup(CGroupsHandler.CGroupController.DEVICES,
+        containerIdStr);
+    // non-Docker, use cgroups to do isolation
+    if (!DockerLinuxContainerRuntime.isDockerContainerRequested(
+        nmContext.getConf(),
+        container.getLaunchContext().getEnvironment())) {
+      tryIsolateDevices(allocation, containerIdStr);
+      List<PrivilegedOperation> ret = new ArrayList<>();
+      ret.add(new PrivilegedOperation(
+          PrivilegedOperation.OperationType.ADD_PID_TO_CGROUP,
+          PrivilegedOperation.CGROUP_ARG_PREFIX + cGroupsHandler
+              .getPathForCGroupTasks(CGroupsHandler.CGroupController.DEVICES,
+                  containerIdStr)));
 
+      return ret;
+    }
     return null;
   }
 
+  /**
+   * Try to set the device cgroups parameters for the container using
+   * container-executor.
+   * A device is only passed on when it has a real major and minor number;
+   * a denied device additionally needs a resolvable device type. When no
+   * device qualifies, container-executor is not invoked at all.
+   *
+   * @param allocation the allowed and denied devices of the container
+   * @param containerIdStr the container id, also used as the cgroup name
+   * @throws ResourceHandlerException if the privileged operation fails; the
+   *         devices cgroup of the container is deleted before rethrowing
+   * */
+  private void tryIsolateDevices(
+      DeviceMappingManager.DeviceAllocation allocation,
+      String containerIdStr) throws ResourceHandlerException {
+    try {
+      // Execute c-e to setup device isolation before launch the container
+      PrivilegedOperation privilegedOperation = new PrivilegedOperation(
+          PrivilegedOperation.OperationType.DEVICE,
+          Arrays.asList(CONTAINER_ID_CLI_OPTION, containerIdStr));
+      boolean needNativeDeviceOperation = false;
+
+      // Denied devices, each formatted as "<type>-<major>:<minor>-rwm"
+      List<String> deniedValues = new ArrayList<>();
+      for (Device deniedDevice : allocation.getDenied()) {
+        int majorNumber = deniedDevice.getMajorNumber();
+        int minorNumber = deniedDevice.getMinorNumber();
+        if (majorNumber == -1 || minorNumber == -1) {
+          // Without a real device number the value would be malformed
+          // (e.g. "c--1:-1-rwm"), so skip it like the allowed case does.
+          LOG.warn("Skip denying device without device number: "
+              + deniedDevice);
+          continue;
+        }
+        // Add device type
+        DeviceType devType = getDeviceType(deniedDevice);
+        if (devType != null) {
+          deniedValues.add(devType.getName() + "-" + majorNumber + ":"
+              + minorNumber + "-rwm");
+        }
+      }
+      if (!deniedValues.isEmpty()) {
+        privilegedOperation.appendArgs(
+            Arrays.asList(EXCLUDED_DEVICES_CLI_OPTION,
+                StringUtils.join(",", deniedValues)));
+        needNativeDeviceOperation = true;
+      }
+
+      // Allowed devices, each formatted as "<major>:<minor>"
+      List<String> allowedValues = new ArrayList<>();
+      for (Device allowedDevice : allocation.getAllowed()) {
+        int majorNumber = allowedDevice.getMajorNumber();
+        int minorNumber = allowedDevice.getMinorNumber();
+        if (majorNumber != -1 && minorNumber != -1) {
+          allowedValues.add(majorNumber + ":" + minorNumber);
+        }
+      }
+      if (!allowedValues.isEmpty()) {
+        privilegedOperation.appendArgs(
+            Arrays.asList(ALLOWED_DEVICES_CLI_OPTION,
+                StringUtils.join(",", allowedValues)));
+        needNativeDeviceOperation = true;
+      }
+
+      if (needNativeDeviceOperation) {
+        privilegedOperationExecutor.executePrivilegedOperation(
+            privilegedOperation, true);
+      }
+    } catch (PrivilegedOperationException e) {
+      cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
+          containerIdStr);
+      LOG.warn("Could not update cgroup for container", e);
+      throw new ResourceHandlerException(e);
+    }
+  }
+
   @Override
   public synchronized List<PrivilegedOperation> reacquireContainer(
       ContainerId containerId) throws ResourceHandlerException {
@@ -134,6 +253,8 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
   public synchronized List<PrivilegedOperation> postComplete(
       ContainerId containerId) throws ResourceHandlerException {
     deviceMappingManager.cleanupAssignedDevices(resourceName, containerId);
+    cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
+        containerId.toString());
     return null;
   }
 
@@ -151,4 +272,73 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
         ", devicePluginAdapter=" + devicePluginAdapter +
         '}';
   }
+
+  /**
+   * Decide whether a device is a character or a block device.
+   * When the device has a path, the file type reported by the shell wrapper
+   * for that path is used. Without a path, the type is derived from the
+   * major:minor device number.
+   *
+   * @param device the device to inspect
+   * @return the device type, or null when it cannot be decided because the
+   *         device has neither a path nor a complete device number, or
+   *         because the file type lookup failed
+   */
+  public DeviceType getDeviceType(Device device) {
+    String devName = device.getDevPath();
+    if (devName == null || devName.isEmpty()) {
+      LOG.warn("Empty device path provided, try to get device type from " +
+          "major:minor device number");
+      int major = device.getMajorNumber();
+      int minor = device.getMinorNumber();
+      // Both parts are needed: a half-known number identifies no device
+      if (major == -1 || minor == -1) {
+        LOG.warn("Non device number provided, cannot decide the device type");
+        return null;
+      }
+      // Get type from the device numbers
+      return getDeviceTypeFromDeviceNumber(major, minor);
+    }
+    try {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Try to get device type from device path: " + devName);
+      }
+      String output = shellWrapper.getDeviceFileType(devName);
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("stat output:" + output);
+      }
+      return output.startsWith("c") ? DeviceType.CHAR : DeviceType.BLOCK;
+    } catch (IOException e) {
+      // Keep the cause in the log so that the failing command can be found
+      LOG.warn("Failed to get device type from stat " + devName, e);
+      return null;
+    }
+  }
+
+  /**
+   * Get the device type used for cgroups value set.
+   * If sys file "/sys/dev/block/major:minor" exists, it's block device.
+   * Otherwise, it's char device. An exception is that Nvidia GPU doesn't
+   * create this sys file. so assume character device by default.
+   */
+  public DeviceType getDeviceTypeFromDeviceNumber(int major, int minor) {
+    if (shellWrapper.existFile("/sys/dev/block/"
+        + major + ":" + minor)) {
+      return DeviceType.BLOCK;
+    }
+    return DeviceType.CHAR;
+  }
+
+  /**
+   * Linux device types, as written into the device cgroups parameters.
+   * BLOCK is written as "b" and CHAR as "c".
+   * */
+  private enum DeviceType {
+    BLOCK("b"),
+    CHAR("c");
+
+    /** Single letter identifying the type in a cgroups value. */
+    private final String typeLetter;
+
+    DeviceType(String letter) {
+      typeLetter = letter;
+    }
+
+    public String getName() {
+      return typeLetter;
+    }
+  }
+
 }

+ 46 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/ShellWrapper.java

@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
+
+import org.apache.hadoop.util.Shell;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * A shell Wrapper to ease testing.
+ * */
+public class ShellWrapper {
+
+  public String getDeviceFileType(String devName) throws IOException {
+    Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(
+        new String[]{"stat", "-c", "%F", devName});
+    shexec.execute();
+    return shexec.getOutput();
+  }
+
+  public boolean existFile(String path) {
+    File searchFile =
+        new File(path);
+    if (searchFile.exists()) {
+      return true;
+    }
+    return false;
+  }
+}

+ 6 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/main.c

@@ -24,6 +24,7 @@
 #include "modules/gpu/gpu-module.h"
 #include "modules/fpga/fpga-module.h"
 #include "modules/cgroups/cgroups-operations.h"
+#include "modules/devices/devices-module.h"
 #include "utils/string-utils.h"
 
 #include <errno.h>
@@ -289,6 +290,11 @@ static int validate_arguments(int argc, char **argv , int *operation) {
            &argv[1]);
   }
 
+  if (strcmp("--module-devices", argv[1]) == 0) {
+    return handle_devices_request(&update_cgroups_parameters, "devices", argc - 1,
+          &argv[1]);
+  }
+
   if (strcmp("--checksetup", argv[1]) == 0) {
     *operation = CHECK_SETUP;
     return 0;

+ 1 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/cgroups/cgroups-operations.c

@@ -132,7 +132,7 @@ int update_cgroups_parameters(
     goto cleanup;
   }
 
-  fprintf(ERRORFILE, "CGroups: Updating cgroups, path=%s, value=%s",
+  fprintf(ERRORFILE, "CGroups: Updating cgroups, path=%s, value=%s\n",
     full_path, value);
 
   // Write values to file

+ 281 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.c

@@ -0,0 +1,281 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "configuration.h"
+#include "container-executor.h"
+#include "utils/string-utils.h"
+#include "modules/devices/devices-module.h"
+#include "modules/cgroups/cgroups-operations.h"
+#include "modules/common/module-configs.h"
+#include "modules/common/constants.h"
+#include "util.h"
+
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+#include <getopt.h>
+#include <unistd.h>
+#include <sys/stat.h>
+
+#define EXCLUDED_DEVICES_OPTION "excluded_devices"
+#define ALLOWED_DEVICES_OPTION "allowed_devices"
+#define CONTAINER_ID_OPTION "container_id"
+#define MAX_CONTAINER_ID_LEN 128
+
+static const struct section* cfg_section;
+
+// Locate the "major:minor" part of a device token. A token is either a bare
+// device number ("243:1") or a value as passed on the command line
+// ("c-243:1-rwm"). Returns the start of the device number and stores its
+// length in *length; surrounding spaces are not part of the result.
+static const char* find_device_number(const char* token, size_t* length) {
+  const char* start = token;
+  const char* first_dash = strchr(token, '-');
+  if (first_dash != NULL) {
+    // Skip the leading "<type>-"
+    start = first_dash + 1;
+  }
+  while (*start == ' ') {
+    start++;
+  }
+  // The device number ends at the "-rwm" suffix, a space or the string end
+  *length = strcspn(start, "- ");
+  return start;
+}
+
+// Search a device in a string list, return 1 when found.
+// Entries and token are compared by their exact "major:minor" device number,
+// so "243:1" matches "c-243:1-rwm" but does not match "243:10".
+static int search_in_list(char** list, char* token) {
+  int i = 0;
+  size_t token_len = 0;
+  const char* token_number = NULL;
+  if (list == NULL || token == NULL) {
+    return 0;
+  }
+  token_number = find_device_number(token, &token_len);
+  if (token_len == 0) {
+    // No device number to look for
+    return 0;
+  }
+  // search token in list
+  while (list[i] != NULL) {
+    size_t entry_len = 0;
+    const char* entry_number = find_device_number(list[i], &entry_len);
+    if (entry_len == token_len &&
+        strncmp(entry_number, token_number, token_len) == 0) {
+      return 1;
+    }
+    i++;
+  }
+  return 0;
+}
+
+// Tell whether the device number "major:minor" belongs to a block device,
+// i.e. whether "/sys/dev/block/<major:minor>" exists.
+// Returns 1 for a block device and 0 otherwise, including when the sys path
+// cannot be built.
+static int is_block_device(const char* value) {
+  int is_block = 0;
+  int max_path_size = 512;
+  int written = 0;
+  struct stat sb;
+  char* block_path = malloc(max_path_size);
+  if (block_path == NULL) {
+    fprintf(ERRORFILE, "Failed to allocate memory for sys device path string.\n");
+    fflush(ERRORFILE);
+    goto cleanup;
+  }
+  written = snprintf(block_path, max_path_size, "/sys/dev/block/%s", value);
+  // A return value >= the buffer size means the path was truncated; a
+  // truncated path must not be checked as if it were the real one.
+  if (written < 0 || written >= max_path_size) {
+    fprintf(ERRORFILE, "Failed to construct system block device path.\n");
+    goto cleanup;
+  }
+  // file exists, is block device
+  if (stat(block_path, &sb) == 0) {
+    is_block = 1;
+  }
+cleanup:
+  // free(NULL) is a no-op, so no check is needed
+  free(block_path);
+  return is_block;
+}
+
+// Apply the "deny" device cgroups updates for one container.
+// 1. When container-executor.cfg configures denied device numbers, fail if
+//    any device of allow_devices_number_tokens is among them, then deny every
+//    configured number that is not already in deny_devices_number_tokens.
+// 2. Deny every device of deny_devices_number_tokens. Each token is
+//    rewritten in place from "c-242:0-rwm" to "c 242:0 rwm".
+// allow_devices_number_tokens may be NULL (option not given).
+// Returns 0 on success and -1 on failure.
+static int internal_handle_devices_request(
+    update_cgroups_parameters_function update_cgroups_parameters_func_p,
+    char** deny_devices_number_tokens,
+    char** allow_devices_number_tokens,
+    const char* container_id) {
+  int return_code = 0;
+
+  char** ce_denied_numbers = NULL;
+  // NOTE(review): ownership of the returned string is not visible here;
+  // confirm against get_section_value whether it must be freed.
+  char* ce_denied_str = get_section_value(DEVICES_DENIED_NUMBERS,
+     cfg_section);
+  // Get denied "major:minor" device numbers from cfg, if not set, means all
+  // devices can be used by YARN.
+  if (ce_denied_str != NULL) {
+    ce_denied_numbers = split_delimiter(ce_denied_str, ",");
+    if (NULL == ce_denied_numbers) {
+      fprintf(ERRORFILE,
+          "Invalid value set for %s, value=%s\n",
+          DEVICES_DENIED_NUMBERS,
+          ce_denied_str);
+      return_code = -1;
+      goto cleanup;
+    }
+    // Check allowed devices passed in. The list is NULL when
+    // --allowed_devices was not given, in which case nothing is checked.
+    char** allow_iterator = allow_devices_number_tokens;
+    int allow_count = 0;
+    while (allow_iterator != NULL && allow_iterator[allow_count] != NULL) {
+      if (search_in_list(ce_denied_numbers, allow_iterator[allow_count])) {
+        fprintf(ERRORFILE,
+          "Trying to allow device with device number=%s which is not permitted in container-executor.cfg. %s\n",
+          allow_iterator[allow_count],
+          "It could be caused by a mismatch of devices reported by device plugin");
+        return_code = -1;
+        goto cleanup;
+      }
+      allow_count++;
+    }
+
+    // Deny devices configured in c-e.cfg
+    char** ce_iterator = ce_denied_numbers;
+    int ce_count = 0;
+    while (ce_iterator[ce_count] != NULL) {
+      // skip if duplicate with denied numbers passed in
+      if (deny_devices_number_tokens != NULL &&
+          search_in_list(deny_devices_number_tokens, ce_iterator[ce_count])) {
+        ce_count++;
+        continue;
+      }
+      char param_value[128];
+      char type = 'c';
+      memset(param_value, 0, sizeof(param_value));
+      if (is_block_device(ce_iterator[ce_count])) {
+        type = 'b';
+      }
+      snprintf(param_value, sizeof(param_value), "%c %s rwm",
+               type,
+               ce_iterator[ce_count]);
+      // Update device cgroups value
+      int rc = update_cgroups_parameters_func_p("devices", "deny",
+        container_id, param_value);
+
+      if (0 != rc) {
+        fprintf(ERRORFILE, "CGroups: Failed to update cgroups. %s\n", param_value);
+        return_code = -1;
+        goto cleanup;
+      }
+      ce_count++;
+    }
+  }
+
+  // Deny devices passed from java side
+  char** iterator = deny_devices_number_tokens;
+  int count = 0;
+  char* value = NULL;
+  int index = 0;
+  while (iterator != NULL && iterator[count] != NULL) {
+    // Replace like "c-242:0-rwm" to "c 242:0 rwm"
+    value = iterator[count];
+    index = 0;
+    while (value[index] != '\0') {
+      if (value[index] == '-') {
+        value[index] = ' ';
+      }
+      index++;
+    }
+    // Update device cgroups value
+    int rc = update_cgroups_parameters_func_p("devices", "deny",
+      container_id, iterator[count]);
+
+    if (0 != rc) {
+      fprintf(ERRORFILE, "CGroups: Failed to update cgroups\n");
+      return_code = -1;
+      goto cleanup;
+    }
+    count++;
+  }
+
+cleanup:
+  if (ce_denied_numbers != NULL) {
+    free_values(ce_denied_numbers);
+  }
+  return return_code;
+}
+
+void reload_devices_configuration() { // Re-read the "devices" section of the loaded configuration into cfg_section
+  cfg_section = get_configuration_section(DEVICES_MODULE_SECTION_NAME, get_cfg());
+}
+
+/*
+ * Entry point of the devices module: parse the command line and apply the
+ * device cgroups "deny" updates through func.
+ *
+ * Format of devices request commandline:
+ * The excluded_devices is comma separated device cgroups values with device type.
+ * The "-" will be replaced with " " to match the cgroups parameter
+ * c-e --module-devices \
+ * --excluded_devices b-8:16-rwm,c-244:0-rwm,c-244:1-rwm \
+ * --allowed_devices 8:32,8:48,243:2 \
+ * --container_id container_x_y
+ *
+ * Returns 0 on success, also when --excluded_devices is absent and the
+ * cgroups call is skipped. Returns non-zero when the module is disabled, an
+ * option is unknown, the container id is missing or invalid, or a cgroups
+ * update fails.
+ */
+int handle_devices_request(update_cgroups_parameters_function func,
+    const char* module_name, int module_argc, char** module_argv) {
+  if (!cfg_section) {
+    reload_devices_configuration();
+  }
+
+  if (!module_enabled(cfg_section, DEVICES_MODULE_SECTION_NAME)) {
+    fprintf(ERRORFILE,
+      "Please make sure devices module is enabled before using it.\n");
+    return -1;
+  }
+
+  static struct option long_options[] = {
+    {EXCLUDED_DEVICES_OPTION, required_argument, 0, 'e' },
+    {ALLOWED_DEVICES_OPTION, required_argument, 0, 'a' },
+    {CONTAINER_ID_OPTION, required_argument, 0, 'c' },
+    {0, 0, 0, 0}
+  };
+
+  int c = 0;
+  int option_index = 0;
+
+  char** deny_device_value_tokens = NULL;
+  char** allow_device_value_tokens = NULL;
+  char container_id[MAX_CONTAINER_ID_LEN];
+  memset(container_id, 0, sizeof(container_id));
+  int failed = 0;
+
+  // Restart option parsing, the caller may have used getopt before
+  optind = 1;
+  while((c = getopt_long(module_argc, module_argv, "e:a:c:",
+                         long_options, &option_index)) != -1) {
+    switch(c) {
+      case 'e':
+        // A repeated option replaces the earlier value; release it first
+        if (deny_device_value_tokens) {
+          free_values(deny_device_value_tokens);
+        }
+        deny_device_value_tokens = split_delimiter(optarg, ",");
+        break;
+      case 'a':
+        if (allow_device_value_tokens) {
+          free_values(allow_device_value_tokens);
+        }
+        allow_device_value_tokens = split_delimiter(optarg, ",");
+        break;
+      case 'c':
+        if (!validate_container_id(optarg)) {
+          fprintf(ERRORFILE,
+            "Specified container_id=%s is invalid\n", optarg);
+          failed = 1;
+          goto cleanup;
+        }
+        // strncpy does not terminate the destination when the source fills
+        // it completely, so copy one byte less and terminate explicitly.
+        strncpy(container_id, optarg, MAX_CONTAINER_ID_LEN - 1);
+        container_id[MAX_CONTAINER_ID_LEN - 1] = '\0';
+        break;
+      default:
+        fprintf(ERRORFILE,
+          "Unknown option in devices command character %d %c, optionindex = %d\n",
+          c, c, optind);
+        failed = 1;
+        goto cleanup;
+    }
+  }
+
+  if (0 == container_id[0]) {
+    fprintf(ERRORFILE,
+      "[%s] --container_id must be specified.\n", __func__);
+    failed = 1;
+    goto cleanup;
+  }
+
+  if (NULL == deny_device_value_tokens) {
+     // Devices number is null, skip following call.
+     fprintf(ERRORFILE, "--excluded_devices is not specified, skip cgroups call.\n");
+     goto cleanup;
+  }
+
+  failed = internal_handle_devices_request(func,
+         deny_device_value_tokens,
+         allow_device_value_tokens,
+         container_id);
+
+cleanup:
+  if (deny_device_value_tokens) {
+    free_values(deny_device_value_tokens);
+  }
+  if (allow_device_value_tokens) {
+    free_values(allow_device_value_tokens);
+  }
+  return failed;
+}

+ 45 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/modules/devices/devices-module.h

@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef __FreeBSD__
+#define _WITH_GETLINE
+#endif
+
+#ifndef _MODULES_DEVICES_MUDULE_H_
+#define _MODULES_DEVICES_MUDULE_H_
+
+// Denied device list. value format is "major1:minor1,major2:minor2"
+#define DEVICES_DENIED_NUMBERS "devices.denied-numbers"
+#define DEVICES_MODULE_SECTION_NAME "devices"
+
+// For unit test stubbing
+typedef int (*update_cgroups_parameters_function)(const char*, const char*,
+   const char*, const char*);
+
+/**
+ * Handle devices requests
+ */
+int handle_devices_request(update_cgroups_parameters_function func,
+   const char* module_name, int module_argc, char** module_argv);
+
+/**
+ * Reload config from filesystem, visible for testing.
+ */
+void reload_devices_configuration();
+
+#endif

+ 3 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/impl/util.c

@@ -44,6 +44,9 @@ char** split_delimiter(char *value, const char *delim) {
     memset(return_values, 0, sizeof(char *) * return_values_size);
 
     temp_tok = strtok_r(value, delim, &tempstr);
+    if (NULL == temp_tok) {
+      return_values[size++] = strdup(value);
+    }
     while (temp_tok != NULL) {
       temp_tok = strdup(temp_tok);
       if (NULL == temp_tok) {

+ 298 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/native/container-executor/test/modules/devices/test-devices-module.cc

@@ -0,0 +1,298 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include <errno.h>
+#include <fcntl.h>
+#include <inttypes.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include <gtest/gtest.h>
+#include <sstream>
+
+extern "C" {
+#include "configuration.h"
+#include "container-executor.h"
+#include "modules/cgroups/cgroups-operations.h"
+#include "modules/devices/devices-module.h"
+#include "test/test-container-executor-common.h"
+#include "util.h"
+}
+
+namespace ContainerExecutor {
+
+// Fixture for the devices module tests: routes the executor's log streams to
+// stdout/stderr and makes sure the test root directory exists.
+class TestDevicesModule : public ::testing::Test {
+protected:
+  virtual void SetUp() {
+    // Assign the log streams first so that the failure message below is
+    // written to a stream this fixture has set up.
+    LOGFILE = stdout;
+    ERRORFILE = stderr;
+    if (mkdirs(TEST_ROOT, 0755) != 0) {
+      fprintf(ERRORFILE, "Failed to mkdir TEST_ROOT: %s\n", TEST_ROOT);
+      exit(1);
+    }
+  }
+
+  virtual void TearDown() {
+
+  }
+};
+
+static std::vector<const char*> cgroups_parameters_invoked;
+
+// Test double for update_cgroups_parameters: records a heap copy of the four
+// arguments, in call order, in cgroups_parameters_invoked and reports
+// success. The copies are released by clear_cgroups_parameters_invoked().
+static int mock_update_cgroups_parameters(
+   const char* controller_name,
+   const char* param_name,
+   const char* group_id,
+   const char* value) {
+  const char* recorded[] = { controller_name, param_name, group_id, value };
+  for (size_t i = 0; i < sizeof(recorded) / sizeof(recorded[0]); i++) {
+    // strdup sizes each copy to its string, so a long value cannot overflow
+    // a fixed-size buffer
+    cgroups_parameters_invoked.push_back(strdup(recorded[i]));
+  }
+  return 0;
+}
+
+// Release every recorded argument copy and empty the record.
+static void clear_cgroups_parameters_invoked() {
+  std::vector<const char*>::iterator it = cgroups_parameters_invoked.begin();
+  for (; it != cgroups_parameters_invoked.end(); ++it) {
+    free((void *) *it);
+  }
+  cgroups_parameters_invoked.clear();
+}
+
+// Assert that exactly the given arguments were recorded by the mock, in the
+// given order.
+static void verify_param_updated_to_cgroups(
+    int argc, const char** argv) {
+  ASSERT_EQ(argc, cgroups_parameters_invoked.size());
+
+  for (int offset = 0; offset < argc; offset++) {
+    ASSERT_STREQ(argv[offset], cgroups_parameters_invoked[offset]);
+  }
+}
+
+// Write a configuration file with only a [devices] section whose
+// module.enabled flag follows `enabled`, then load it into the executor and
+// the devices module.
+static void write_and_load_devices_module_to_cfg(const char* cfg_filepath, int enabled) {
+  FILE *file = fopen(cfg_filepath, "w");
+  if (file == NULL) {
+    printf("FAIL: Could not open configuration file: %s\n", cfg_filepath);
+    exit(1);
+  }
+  const char* flag = enabled ? "true" : "false";
+  fprintf(file, "[devices]\n");
+  fprintf(file, "module.enabled=%s\n", flag);
+  fclose(file);
+
+  // Read config file
+  read_executor_config(cfg_filepath);
+  reload_devices_configuration();
+}
+
+// Append `values` verbatim to the configuration file and load the file again
+// into the executor and the devices module.
+static void append_config(const char* cfg_filepath, char values[]) {
+  FILE *file = fopen(cfg_filepath, "a");
+  if (file == NULL) {
+    printf("FAIL: Could not open configuration file: %s\n", cfg_filepath);
+    exit(1);
+  }
+  fputs(values, file);
+  fclose(file);
+
+  // Read config file
+  read_executor_config(cfg_filepath);
+  reload_devices_configuration();
+}
+
+// Send one devices request with the module enabled or disabled and check the
+// return code: 0 when enabled, -1 when disabled.
+static void test_devices_module_enabled_disabled(int enabled) {
+  // Write config file.
+  const char *filename = TEST_ROOT "/test_cgroups_module_enabled_disabled.cfg";
+  write_and_load_devices_module_to_cfg(filename, enabled);
+  char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+  char allowed_devices[] = "243:2";
+  char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices,
+                   (char*) "--allowed_devices",
+                   allowed_devices,
+                   (char*) "--container_id",
+                   (char*) "container_1498064906505_0001_01_000001" };
+
+  int rc = handle_devices_request(&mock_update_cgroups_parameters,
+              "devices", 7, argv);
+
+  const int expected_rc = enabled ? 0 : -1;
+  ASSERT_EQ(expected_rc, rc);
+
+  clear_cgroups_parameters_invoked();
+  free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_verify_device_module_calls_cgroup_parameter) {
+  // Write a config file with the devices module enabled and load it.
+  const char *filename = TEST_ROOT "/test_verify_devices_module_calls_cgroup_parameter.cfg";
+  write_and_load_devices_module_to_cfg(filename, 1);
+
+  char* container_id = (char*) "container_1498064906505_0001_01_000001";
+  char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+  char allowed_devices[] = "243:2";
+  char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices,
+                   (char*) "--allowed_devices",
+                   allowed_devices,
+                   (char*) "--container_id",
+                   container_id };
+  /* Test case 1: two excluded devices must produce two "deny" updates */
+  clear_cgroups_parameters_invoked();
+  int rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 7, argv);
+  ASSERT_EQ(0, rc) << "Should success.\n";
+  // Verify the recorded cgroups parameters: the "-" separators of each
+  // excluded value are expected to arrive as spaces
+  const char* expected_cgroups_argv[] = { "devices", "deny", container_id, "c 243:0 rwm",
+    "devices", "deny", container_id, "c 243:1 rwm"};
+  verify_param_updated_to_cgroups(8, expected_cgroups_argv);
+
+  /* Test case 2: no excluded devices, only the container id is passed */
+  clear_cgroups_parameters_invoked();
+  char* argv_1[] = { (char*) "--module-devices", (char*) "--container_id", container_id };
+  rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 3, argv_1);
+  ASSERT_EQ(0, rc) << "Should success.\n";
+
+  // Verify that no cgroups update was recorded
+  verify_param_updated_to_cgroups(0, NULL);
+
+  clear_cgroups_parameters_invoked();
+  free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_update_cgroup_parameter_with_config) {
+  // Write the config file and load it (second argument 1 presumably enables the module).
+  const char *filename = TEST_ROOT "/test_update_cgroup_parameter_with_config.cfg";
+  write_and_load_devices_module_to_cfg(filename, 1);
+  // Append a denied device number (243:1) to the config file
+  char tokens[] = "devices.denied-numbers=243:1\n";
+  append_config(filename, tokens);
+
+  char* container_id = (char*) "container_1498064906505_0001_01_000001";
+  char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+  char allowed_devices[] = "243:2";
+  char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices,
+                   (char*) "--allowed_devices",
+                   allowed_devices,
+                   (char*) "--container_id",
+                   container_id };
+  /* Test case 1: block 2 devices; the allowed device (243:2) is not the denied one */
+  clear_cgroups_parameters_invoked();
+  int rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 7, argv);
+  ASSERT_EQ(0, rc) << "Should success.\n";
+  // Verify cgroups parameters: one "deny" update per excluded device
+  const char* expected_cgroups_argv[] = { "devices", "deny", container_id, "c 243:0 rwm",
+    "devices", "deny", container_id, "c 243:1 rwm"};
+  verify_param_updated_to_cgroups(8, expected_cgroups_argv);
+
+  /* Test case 2: block 2 devices but try to allow a device denied by the config */
+  clear_cgroups_parameters_invoked();
+  // The device plugin reported devices 0,1,2,3 in total; 1 and 2 are allocated.
+  // But the config appended above denies device 243:1, so the request must fail.
+  char excluded_devices2[] = "c-243:0-rwm,c-243:3-rwm";
+  char allowed_devices2[] = "243:1,243:2";
+  char* argv1[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices2,
+                   (char*) "--allowed_devices",
+                   allowed_devices2,
+                   (char*) "--container_id",
+                   container_id };
+  rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 7, argv1);
+  ASSERT_NE(0, rc) << "Should fail.\n";
+
+  clear_cgroups_parameters_invoked();
+  free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_illegal_cli_parameters) {
+  // Write the config file and load it (second argument 1 presumably enables the module).
+  const char *filename = TEST_ROOT "/test_illegal_cli_parameters.cfg";
+  write_and_load_devices_module_to_cfg(filename, 1);
+  char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
+  char allowed_devices[] = "243:2";
+  // Illegal container id - 1: "xxxx" must be rejected
+  char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices,
+                   (char*) "--allowed_devices",
+                   allowed_devices,
+                   (char*) "--container_id", (char*) "xxxx" };
+  int rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 7, argv);
+  ASSERT_NE(0, rc) << "Should fail.\n";
+
+  // Illegal container id - 2: "container_1" must be rejected
+  clear_cgroups_parameters_invoked();
+  char* argv_1[] = { (char*) "--module-devices", (char*) "--excluded_devices",
+                   excluded_devices,
+                   (char*) "--allowed_devices",
+                   allowed_devices,
+                   (char*) "--container_id", (char*) "container_1" };
+  rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 7, argv_1);
+  ASSERT_NE(0, rc) << "Should fail.\n";
+
+  // Illegal container id - 3: the --container_id option is missing entirely
+  clear_cgroups_parameters_invoked();
+  char* argv_2[] = { (char*) "--module-devices",
+                     (char*) "--excluded_devices",
+                     excluded_devices };
+  rc = handle_devices_request(&mock_update_cgroups_parameters,
+     "devices", 3, argv_2);
+  ASSERT_NE(0, rc) << "Should fail.\n";
+
+  clear_cgroups_parameters_invoked();
+  free_executor_configurations();
+}
+
+TEST_F(TestDevicesModule, test_devices_module_disabled) {
+  test_devices_module_enabled_disabled(0); // flag 0: the helper asserts rc == -1
+}
+
+TEST_F(TestDevicesModule, test_devices_module_enabled) {
+  test_devices_module_enabled_disabled(1); // flag 1: the helper asserts rc == 0
+}
+} // namespace ContainerExecutor

+ 28 - 25
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDeviceMappingManager.java

@@ -25,6 +25,7 @@ import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
@@ -33,6 +34,8 @@ import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeS
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
@@ -74,6 +77,10 @@ public class TestDeviceMappingManager {
   private ExecutorService containerLauncher;
   private Configuration conf;
 
+  private CGroupsHandler mockCGroupsHandler;
+  private PrivilegedOperationExecutor mockPrivilegedExecutor;
+  private Context mockCtx;
+
   @Before
   public void setup() throws Exception {
     // setup resource-types.xml
@@ -89,7 +96,7 @@ public class TestDeviceMappingManager {
         isA(String.class),
         isA(ArrayList.class));
     dmm = new DeviceMappingManager(context);
-    int deviceCount = 600;
+    int deviceCount = 100;
     TreeSet<Device> r = new TreeSet<>();
     for (int i = 0; i < deviceCount; i++) {
       r.add(Device.Builder.newInstance()
@@ -117,6 +124,10 @@ public class TestDeviceMappingManager {
 
     containerLauncher =
         Executors.newFixedThreadPool(10);
+    mockCGroupsHandler = mock(CGroupsHandler.class);
+    mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
+    mockCtx = mock(NodeManager.NMContext.class);
+    when(mockCtx.getConf()).thenReturn(conf);
   }
 
   @After
@@ -134,7 +145,7 @@ public class TestDeviceMappingManager {
   @Test
   public void testAllocation()
       throws InterruptedException, ResourceHandlerException {
-    int totalContainerCount = 100;
+    int totalContainerCount = 10;
     String resourceName1 = "cmpA.com/hdwA";
     String resourceName2 = "cmp.com/cmp";
     DeviceMappingManager dmmSpy = spy(dmm);
@@ -158,11 +169,12 @@ public class TestDeviceMappingManager {
           resourceName,
           num, false);
       containerSet.get(resourceName).put(c, num);
-
+      DevicePlugin myPlugin = new MyTestPlugin();
+      DevicePluginAdapter dpa = new DevicePluginAdapter(resourceName,
+          myPlugin, dmm);
       DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
-          resourceName,
-          new MyTestPlugin(), null,
-          dmmSpy, null, null);
+          resourceName, dpa,
+          dmmSpy, mockCGroupsHandler, mockPrivilegedExecutor, mockCtx);
       Future<Integer> f = containerLauncher.submit(new MyContainerLaunch(
           dri, c, i, false));
     }
@@ -173,12 +185,11 @@ public class TestDeviceMappingManager {
     }
 
     Long endTime = System.currentTimeMillis();
-    LOG.info("Each container allocation spends roughly: {} ms",
+    LOG.info("Each container preStart spends roughly: {} ms",
         (endTime - startTime)/totalContainerCount);
     // Ensure invocation times
     verify(dmmSpy, times(totalContainerCount)).assignDevices(
         anyString(), any(Container.class));
-
     // Ensure used devices' count for each type is correct
     int totalAllocatedCount = 0;
     Map<Device, ContainerId> used1 =
@@ -198,23 +209,15 @@ public class TestDeviceMappingManager {
     for (Map.Entry<Container, Integer> entry :
         containerSet.get(resourceName1).entrySet()) {
       int containerWanted = entry.getValue();
-      int actualAllocated = 0;
-      for (ContainerId cid : used1.values()) {
-        if (cid.equals(entry.getKey().getContainerId())) {
-          actualAllocated++;
-        }
-      }
+      int actualAllocated = dmm.getAllocatedDevices(resourceName1,
+          entry.getKey().getContainerId()).size();
       Assert.assertEquals(containerWanted, actualAllocated);
     }
     for (Map.Entry<Container, Integer> entry :
         containerSet.get(resourceName2).entrySet()) {
       int containerWanted = entry.getValue();
-      int actualAllocated = 0;
-      for (ContainerId cid : used2.values()) {
-        if (cid.equals(entry.getKey().getContainerId())) {
-          actualAllocated++;
-        }
-      }
+      int actualAllocated = dmm.getAllocatedDevices(resourceName2,
+          entry.getKey().getContainerId()).size();
       Assert.assertEquals(containerWanted, actualAllocated);
     }
   }
@@ -248,11 +251,12 @@ public class TestDeviceMappingManager {
           resourceName,
           num, false);
       containerSet.get(resourceName).put(c, num);
-
+      DevicePlugin myPlugin = new MyTestPlugin();
+      DevicePluginAdapter dpa = new DevicePluginAdapter(resourceName,
+          myPlugin, dmm);
       DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
-          resourceName,
-          new MyTestPlugin(), null,
-          dmmSpy, null, null);
+          resourceName, dpa,
+          dmmSpy, mockCGroupsHandler, mockPrivilegedExecutor, mockCtx);
       Future<Integer> f = containerLauncher.submit(new MyContainerLaunch(
           dri, c, i, true));
     }
@@ -262,7 +266,6 @@ public class TestDeviceMappingManager {
       LOG.info("Wait for the threads to finish");
     }
 
-
     // Ensure invocation times
     verify(dmmSpy, times(totalContainerCount)).assignDevices(
         anyString(), any(Container.class));

+ 383 - 29
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/TestDevicePluginAdapter.java

@@ -18,7 +18,6 @@
 
 package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
 
-import org.apache.hadoop.service.ServiceOperations;
 
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
@@ -34,12 +33,20 @@ import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePluginManager;
 import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
 import org.apache.hadoop.yarn.server.nodemanager.recovery.NMMemoryStateStoreService;
@@ -51,6 +58,7 @@ import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -60,15 +68,21 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeSet;
 import java.util.concurrent.ConcurrentHashMap;
 
+
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.isA;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.times;
@@ -89,7 +103,6 @@ public class TestDevicePluginAdapter {
   private String tempResourceTypesFile;
   private CGroupsHandler mockCGroupsHandler;
   private PrivilegedOperationExecutor mockPrivilegedExecutor;
-  private NodeManager nm;
 
   @Before
   public void setup() throws Exception {
@@ -110,13 +123,6 @@ public class TestDevicePluginAdapter {
     if (dest.exists()) {
       dest.delete();
     }
-    if (nm != null) {
-      try {
-        ServiceOperations.stop(nm);
-      } catch (Throwable t) {
-        // ignore
-      }
-    }
   }
 
 
@@ -130,16 +136,14 @@ public class TestDevicePluginAdapter {
     NodeManager.NMContext context = mock(NodeManager.NMContext.class);
     NMStateStoreService storeService = mock(NMStateStoreService.class);
     when(context.getNMStateStore()).thenReturn(storeService);
+    when(context.getConf()).thenReturn(this.conf);
     doNothing().when(storeService).storeAssignedResources(isA(Container.class),
         isA(String.class),
         isA(ArrayList.class));
-
     // Init scheduler manager
     DeviceMappingManager dmm = new DeviceMappingManager(context);
-
     ResourcePluginManager rpm = mock(ResourcePluginManager.class);
     when(rpm.getDeviceMappingManager()).thenReturn(dmm);
-
     // Init an plugin
     MyPlugin plugin = new MyPlugin();
     MyPlugin spyPlugin = spy(plugin);
@@ -150,14 +154,19 @@ public class TestDevicePluginAdapter {
         spyPlugin, dmm);
     // Bootstrap, adding device
     adapter.initialize(context);
-    adapter.createResourceHandler(context,
-        mockCGroupsHandler, mockPrivilegedExecutor);
+    // Use a mock shell wrapper when creating the resource handler
+    ShellWrapper mockShellWrapper = mock(ShellWrapper.class);
+    when(mockShellWrapper.existFile(anyString())).thenReturn(true);
+    when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
+    DeviceResourceHandlerImpl drhl = new DeviceResourceHandlerImpl(resourceName,
+        adapter, dmm, mockCGroupsHandler, mockPrivilegedExecutor, context,
+        mockShellWrapper);
+    adapter.setDeviceResourceHandler(drhl);
     adapter.getDeviceResourceHandler().bootstrap(conf);
     int size = dmm.getAvailableDevices(resourceName);
     Assert.assertEquals(3, size);
-
-    // A container c1 requests 1 device
-    Container c1 = mockContainerWithDeviceRequest(0,
+    // Case 1. A container c1 requests 1 device
+    Container c1 = mockContainerWithDeviceRequest(1,
         resourceName,
         1, false);
     // preStart
@@ -169,19 +178,33 @@ public class TestDevicePluginAdapter {
         dmm.getAllUsedDevices().get(resourceName).size());
     Assert.assertEquals(3,
         dmm.getAllAllowedDevices().get(resourceName).size());
+    Assert.assertEquals(1,
+        dmm.getAllocatedDevices(resourceName, c1.getContainerId()).size());
+    verify(mockShellWrapper, times(2)).getDeviceFileType(anyString());
+    // check device cgroup create operation
+    checkCgroupOperation(c1.getContainerId().toString(), 1,
+        "c-256:1-rwm,c-256:2-rwm", "256:0");
     // postComplete
-    adapter.getDeviceResourceHandler().postComplete(getContainerId(0));
+    adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
     Assert.assertEquals(3,
         dmm.getAvailableDevices(resourceName));
     Assert.assertEquals(0,
         dmm.getAllUsedDevices().get(resourceName).size());
     Assert.assertEquals(3,
         dmm.getAllAllowedDevices().get(resourceName).size());
-
-    // A container c2 requests 3 device
-    Container c2 = mockContainerWithDeviceRequest(1,
+    // check cgroup delete operation
+    verify(mockCGroupsHandler).deleteCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c1.getContainerId().toString());
+    // Case 2. A container c2 requests 3 device
+    Container c2 = mockContainerWithDeviceRequest(2,
         resourceName,
         3, false);
+    reset(mockShellWrapper);
+    reset(mockCGroupsHandler);
+    reset(mockPrivilegedExecutor);
+    when(mockShellWrapper.existFile(anyString())).thenReturn(true);
+    when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
     // preStart
     adapter.getDeviceResourceHandler().preStart(c2);
     // check book keeping
@@ -191,19 +214,37 @@ public class TestDevicePluginAdapter {
         dmm.getAllUsedDevices().get(resourceName).size());
     Assert.assertEquals(3,
         dmm.getAllAllowedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllocatedDevices(resourceName, c2.getContainerId()).size());
+    verify(mockShellWrapper, times(0)).getDeviceFileType(anyString());
+    // check device cgroup create operation
+    verify(mockCGroupsHandler).createCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c2.getContainerId().toString());
+    // check device cgroup update operation
+    checkCgroupOperation(c2.getContainerId().toString(), 1,
+        null, "256:0,256:1,256:2");
     // postComplete
-    adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
+    adapter.getDeviceResourceHandler().postComplete(getContainerId(2));
     Assert.assertEquals(3,
         dmm.getAvailableDevices(resourceName));
     Assert.assertEquals(0,
         dmm.getAllUsedDevices().get(resourceName).size());
     Assert.assertEquals(3,
         dmm.getAllAllowedDevices().get(resourceName).size());
-
-    // A container c3 request 0 device
-    Container c3 = mockContainerWithDeviceRequest(1,
+    // check cgroup delete operation
+    verify(mockCGroupsHandler).deleteCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c2.getContainerId().toString());
+    // Case 3. A container c3 request 0 device
+    Container c3 = mockContainerWithDeviceRequest(3,
         resourceName,
         0, false);
+    reset(mockShellWrapper);
+    reset(mockCGroupsHandler);
+    reset(mockPrivilegedExecutor);
+    when(mockShellWrapper.existFile(anyString())).thenReturn(true);
+    when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
     // preStart
     adapter.getDeviceResourceHandler().preStart(c3);
     // check book keeping
@@ -213,14 +254,57 @@ public class TestDevicePluginAdapter {
         dmm.getAllUsedDevices().get(resourceName).size());
     Assert.assertEquals(3,
         dmm.getAllAllowedDevices().get(resourceName).size());
+    verify(mockShellWrapper, times(3)).getDeviceFileType(anyString());
+    // check device cgroup create operation
+    verify(mockCGroupsHandler).createCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c3.getContainerId().toString());
+    // check device cgroup update operation
+    checkCgroupOperation(c3.getContainerId().toString(), 1,
+        "c-256:0-rwm,c-256:1-rwm,c-256:2-rwm", null);
     // postComplete
-    adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
+    adapter.getDeviceResourceHandler().postComplete(getContainerId(3));
     Assert.assertEquals(3,
         dmm.getAvailableDevices(resourceName));
     Assert.assertEquals(0,
         dmm.getAllUsedDevices().get(resourceName).size());
     Assert.assertEquals(3,
         dmm.getAllAllowedDevices().get(resourceName).size());
+    Assert.assertEquals(0,
+        dmm.getAllocatedDevices(resourceName, c3.getContainerId()).size());
+    // check cgroup delete operation
+    verify(mockCGroupsHandler).deleteCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c3.getContainerId().toString());
+  }
+
+  private void checkCgroupOperation(String cId,
+      int invokeTimesOfPrivilegedExecutor,
+      String excludedParam, String allowedParam)
+      throws PrivilegedOperationException, ResourceHandlerException {
+    verify(mockCGroupsHandler).createCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        cId);
+    // capture the privileged operation that updates the device cgroup
+    ArgumentCaptor<PrivilegedOperation> args =
+        ArgumentCaptor.forClass(PrivilegedOperation.class);
+    verify(mockPrivilegedExecutor, times(invokeTimesOfPrivilegedExecutor))
+        .executePrivilegedOperation(args.capture(), eq(true));
+    Assert.assertEquals(PrivilegedOperation.OperationType.DEVICE,
+        args.getValue().getOperationType());
+    List<String> expectedArgs = new ArrayList<>();
+    expectedArgs.add(DeviceResourceHandlerImpl.CONTAINER_ID_CLI_OPTION);
+    expectedArgs.add(cId); // the container id option is expected first
+    if (excludedParam != null && !excludedParam.isEmpty()) { // optional
+      expectedArgs.add(DeviceResourceHandlerImpl.EXCLUDED_DEVICES_CLI_OPTION);
+      expectedArgs.add(excludedParam);
+    }
+    if (allowedParam != null && !allowedParam.isEmpty()) { // optional
+      expectedArgs.add(DeviceResourceHandlerImpl.ALLOWED_DEVICES_CLI_OPTION);
+      expectedArgs.add(allowedParam);
+    }
+    Assert.assertArrayEquals(expectedArgs.toArray(),
+        args.getValue().getArguments().toArray());
   }
 
   @Test
@@ -251,6 +335,7 @@ public class TestDevicePluginAdapter {
     NMStateStoreService realStoreService = new NMMemoryStateStoreService();
     NMStateStoreService storeService = spy(realStoreService);
     when(context.getNMStateStore()).thenReturn(storeService);
+    when(context.getConf()).thenReturn(this.conf);
     doNothing().when(storeService).storeAssignedResources(isA(Container.class),
         isA(String.class),
         isA(ArrayList.class));
@@ -395,6 +480,7 @@ public class TestDevicePluginAdapter {
     NodeManager.NMContext context = mock(NodeManager.NMContext.class);
     NMStateStoreService realStoreService = new NMMemoryStateStoreService();
     NMStateStoreService storeService = spy(realStoreService);
+    when(context.getConf()).thenReturn(this.conf);
     when(context.getNMStateStore()).thenReturn(storeService);
     doThrow(new IOException("Exception ...")).when(storeService)
         .storeAssignedResources(isA(Container.class),
@@ -448,6 +534,7 @@ public class TestDevicePluginAdapter {
     NodeManager.NMContext context = mock(NodeManager.NMContext.class);
     NMStateStoreService storeService = mock(NMStateStoreService.class);
     when(context.getNMStateStore()).thenReturn(storeService);
+    when(context.getConf()).thenReturn(this.conf);
     doNothing().when(storeService).storeAssignedResources(isA(Container.class),
         isA(String.class),
         isA(ArrayList.class));
@@ -526,6 +613,7 @@ public class TestDevicePluginAdapter {
     NodeManager.NMContext context = mock(NodeManager.NMContext.class);
     NMStateStoreService storeService = mock(NMStateStoreService.class);
     when(context.getNMStateStore()).thenReturn(storeService);
+    when(context.getConf()).thenReturn(this.conf);
     doNothing().when(storeService).storeAssignedResources(isA(Container.class),
         isA(String.class),
         isA(ArrayList.class));
@@ -584,6 +672,206 @@ public class TestDevicePluginAdapter {
     Assert.assertEquals(3, response.getTotalDevices().size());
   }
 
+  /**
+   * Test the container run command update when using the Docker runtime
+   * with a device plugin that behaves like Nvidia Docker v1.
+   */
+  @Test
+  public void testDeviceResourceDockerRuntimePlugin1() throws Exception {
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService storeService = mock(NMStateStoreService.class);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    when(context.getConf()).thenReturn(this.conf);
+    doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+        isA(String.class),
+        isA(ArrayList.class));
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
+    DeviceMappingManager spyDmm = spy(dmm);
+    ResourcePluginManager rpm = mock(ResourcePluginManager.class);
+    when(rpm.getDeviceMappingManager()).thenReturn(spyDmm);
+    // Init a plugin
+    MyPlugin plugin = new MyPlugin();
+    MyPlugin spyPlugin = spy(plugin);
+    String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init an adapter for the plugin
+    DevicePluginAdapter adapter = new DevicePluginAdapter(
+        resourceName,
+        spyPlugin, spyDmm);
+    adapter.initialize(context);
+    // Bootstrap, adding device. NOTE(review): repeated initialize() call?
+    adapter.initialize(context);
+    adapter.createResourceHandler(context,
+        mockCGroupsHandler, mockPrivilegedExecutor);
+    adapter.getDeviceResourceHandler().bootstrap(conf);
+    // Case 1. A container requests the Docker runtime and 1 device
+    Container c1 = mockContainerWithDeviceRequest(1, resourceName, 1, true);
+    // generate spec based on v1
+    spyPlugin.setDevicePluginVersion("v1");
+    // preStart will do allocation
+    adapter.getDeviceResourceHandler().preStart(c1);
+    Set<Device> allocatedDevice = spyDmm.getAllocatedDevices(resourceName,
+        c1.getContainerId());
+    reset(spyDmm);
+    // c1 is requesting the docker runtime.
+    // it will create parent cgroup but no cgroups update operation is needed.
+    // check device cgroup create operation
+    verify(mockCGroupsHandler).createCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c1.getContainerId().toString());
+    // ensure no cgroups update operation
+    verify(mockPrivilegedExecutor, times(0))
+        .executePrivilegedOperation(
+            any(PrivilegedOperation.class), anyBoolean());
+    DockerCommandPlugin dcp = adapter.getDockerCommandPluginInstance();
+    // When DockerLinuxContainerRuntime invokes the DockerCommandPluginInstance
+    // First, create the volume
+    DockerVolumeCommand dvc = dcp.getCreateDockerVolumeCommand(c1);
+    // ensure that the allocation is fetched once from device mapping manager
+    verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+    // ensure that the plugin's onDevicesAllocated is invoked
+    verify(spyPlugin).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DEFAULT);
+    verify(spyPlugin).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    Assert.assertEquals("nvidia-docker", dvc.getDriverName());
+    Assert.assertEquals("create", dvc.getSubCommand());
+    Assert.assertEquals("nvidia_driver_352.68", dvc.getVolumeName());
+
+    // then the DockerLinuxContainerRuntime will update docker run command
+    DockerRunCommand drc =
+        new DockerRunCommand(c1.getContainerId().toString(), "user",
+            "image/tensorflow");
+    // reset so that the invocations above are not counted below
+    reset(spyPlugin);
+    reset(spyDmm);
+    // Second, update the run command.
+    dcp.updateDockerRunCommand(drc, c1);
+    // The spec is already generated in getCreateDockerVolumeCommand
+    // and there should be a cache hit for DeviceRuntime spec.
+    verify(spyPlugin, times(0)).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    // ensure that the allocation is read from cache instead of device mapping
+    // manager
+    verify(spyDmm, times(0)).getAllocatedDevices(resourceName,
+        c1.getContainerId());
+    String runStr = drc.toString();
+    Assert.assertTrue(
+        runStr.contains("nvidia_driver_352.68:/usr/local/nvidia:ro"));
+    Assert.assertTrue(runStr.contains("/dev/hdwA0:/dev/hdwA0"));
+    // Third, cleanup in getCleanupDockerVolumesCommand
+    dcp.getCleanupDockerVolumesCommand(c1);
+    // Ensure the device plugin's onDevicesReleased is invoked
+    verify(spyPlugin).onDevicesReleased(allocatedDevice);
+    // If we run c1 again, no cache will be used for allocation and spec
+    dcp.getCreateDockerVolumeCommand(c1);
+    verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+    verify(spyPlugin).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DOCKER);
+  }
+
+  /**
+   * Test the container run command update when using the Docker runtime
+   * with a device plugin that behaves like Nvidia Docker v2.
+   */
+  @Test
+  public void testDeviceResourceDockerRuntimePlugin2() throws Exception {
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService storeService = mock(NMStateStoreService.class);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    when(context.getConf()).thenReturn(this.conf);
+    doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+        isA(String.class),
+        isA(ArrayList.class));
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
+    DeviceMappingManager spyDmm = spy(dmm);
+    ResourcePluginManager rpm = mock(ResourcePluginManager.class);
+    when(rpm.getDeviceMappingManager()).thenReturn(spyDmm);
+    // Init a plugin
+    MyPlugin plugin = new MyPlugin();
+    MyPlugin spyPlugin = spy(plugin);
+    String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init an adapter for the plugin
+    DevicePluginAdapter adapter = new DevicePluginAdapter(
+        resourceName,
+        spyPlugin, spyDmm);
+    adapter.initialize(context);
+    // Bootstrap, adding device. NOTE(review): repeated initialize() call?
+    adapter.initialize(context);
+    adapter.createResourceHandler(context,
+        mockCGroupsHandler, mockPrivilegedExecutor);
+    adapter.getDeviceResourceHandler().bootstrap(conf);
+    // Case 1. A container requests the Docker runtime and 2 devices
+    Container c1 = mockContainerWithDeviceRequest(1, resourceName, 2, true);
+    // generate spec based on v2
+    spyPlugin.setDevicePluginVersion("v2");
+    // preStart will do allocation
+    adapter.getDeviceResourceHandler().preStart(c1);
+    Set<Device> allocatedDevice = spyDmm.getAllocatedDevices(resourceName,
+        c1.getContainerId());
+    reset(spyDmm);
+    // c1 is requesting the docker runtime.
+    // it will create parent cgroup but no cgroups update operation is needed.
+    // check device cgroup create operation
+    verify(mockCGroupsHandler).createCGroup(
+        CGroupsHandler.CGroupController.DEVICES,
+        c1.getContainerId().toString());
+    // ensure no cgroups update operation
+    verify(mockPrivilegedExecutor, times(0))
+        .executePrivilegedOperation(
+            any(PrivilegedOperation.class), anyBoolean());
+    DockerCommandPlugin dcp = adapter.getDockerCommandPluginInstance();
+    // When DockerLinuxContainerRuntime invokes the DockerCommandPluginInstance
+    // First, ask for the volume creation command
+    DockerVolumeCommand dvc = dcp.getCreateDockerVolumeCommand(c1);
+    // ensure that the allocation is fetched once from device mapping manager
+    verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+    // ensure that the plugin's onDevicesAllocated is invoked
+    verify(spyPlugin).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DEFAULT);
+    verify(spyPlugin).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    // No volume creation request
+    Assert.assertNull(dvc);
+
+    // then the DockerLinuxContainerRuntime will update docker run command
+    DockerRunCommand drc =
+        new DockerRunCommand(c1.getContainerId().toString(), "user",
+            "image/tensorflow");
+    // reset so that the invocations above are not counted below
+    reset(spyPlugin);
+    reset(spyDmm);
+    // Second, update the run command.
+    dcp.updateDockerRunCommand(drc, c1);
+    // The spec is already generated in getCreateDockerVolumeCommand
+    // and there should be a cache hit for DeviceRuntime spec.
+    verify(spyPlugin, times(0)).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    // ensure that the allocation is not fetched from device mapping manager
+    verify(spyDmm, times(0)).getAllocatedDevices(resourceName,
+        c1.getContainerId());
+    Assert.assertEquals("0,1", drc.getEnv().get("NVIDIA_VISIBLE_DEVICES"));
+    Assert.assertTrue(drc.toString().contains("runtime=nvidia"));
+    // Third, cleanup in getCleanupDockerVolumesCommand
+    dcp.getCleanupDockerVolumesCommand(c1);
+    // Ensure the device plugin's onDevicesReleased is invoked
+    verify(spyPlugin).onDevicesReleased(allocatedDevice);
+    // If we run c1 again, no cache will be used for allocation and spec
+    dcp.getCreateDockerVolumeCommand(c1);
+    verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
+    verify(spyPlugin).onDevicesAllocated(
+        allocatedDevice,
+        YarnRuntimeType.RUNTIME_DOCKER);
+  }
+
   private static ContainerId getContainerId(int id) {
     return ContainerId.newContainerId(ApplicationAttemptId
         .newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
@@ -591,6 +879,15 @@ public class TestDevicePluginAdapter {
 
   private class MyPlugin implements DevicePlugin, DevicePluginScheduler {
     private final static String RESOURCE_NAME = "cmpA.com/hdwA";
+
+    // v1 means the vendor uses an approach similar to Nvidia Docker v1
+    // v2 means the vendor uses an approach similar to Nvidia Docker v2
+    private String devicePluginVersion = "v2";
+
+    /** Selects which Nvidia Docker flavour ("v1" or "v2") the spec mimics. */
+    public void setDevicePluginVersion(String version) {
+      this.devicePluginVersion = version;
+    }
+
     @Override
     public DeviceRegisterRequest getRegisterRequestInfo() {
       return DeviceRegisterRequest.Builder.newInstance()
@@ -613,7 +910,7 @@ public class TestDevicePluginAdapter {
           .setId(1)
           .setDevPath("/dev/hdwA1")
           .setMajorNumber(256)
-          .setMinorNumber(0)
+          .setMinorNumber(1)
           .setBusID("0000:80:01.0")
           .setHealthy(true)
           .build());
@@ -621,7 +918,7 @@ public class TestDevicePluginAdapter {
           .setId(2)
           .setDevPath("/dev/hdwA2")
           .setMajorNumber(256)
-          .setMinorNumber(0)
+          .setMinorNumber(2)
           .setBusID("0000:80:02.0")
           .setHealthy(true)
           .build());
@@ -631,12 +928,69 @@ public class TestDevicePluginAdapter {
     @Override
     public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
         YarnRuntimeType yarnRuntime) throws Exception {
+      // A runtime spec is only produced for the Docker runtime.
+      // Any other runtime, RUNTIME_DEFAULT included, falls through to the
+      // null returned at the end of this method.
+      if (YarnRuntimeType.RUNTIME_DOCKER == yarnRuntime) {
+        return generateSpec(devicePluginVersion, allocatedDevices);
+      }
       return null;
     }
 
+    /**
+     * Builds the runtime spec for the allocated devices, mimicking either
+     * Nvidia Docker v1 (volume plus device mounts) or Nvidia Docker v2
+     * (container runtime plus environment variable).
+     *
+     * @param version "v1" or "v2"; for any other value (including null)
+     *                nothing is added to the builder before it is built
+     * @param allocatedDevices devices assigned to the container
+     * @return the spec produced by the builder
+     */
+    private DeviceRuntimeSpec generateSpec(String version,
+        Set<Device> allocatedDevices) {
+      DeviceRuntimeSpec.Builder builder =
+          DeviceRuntimeSpec.Builder.newInstance();
+      if ("v1".equals(version)) {
+        // Nvidia v1 example below. This info is obtained from the Nvidia v1
+        // RESTful API.
+        // --device=/dev/nvidiactl --device=/dev/nvidia-uvm
+        // --device=/dev/nvidia0
+        // --volume-driver=nvidia-docker
+        // --volume=nvidia_driver_352.68:/usr/local/nvidia:ro
+        String volumeDriverName = "nvidia-docker";
+        String volumeToBeCreated = "nvidia_driver_352.68";
+        String volumePathInContainer = "/usr/local/nvidia";
+        // describe volumes to be created and mounted
+        builder.addVolumeSpec(
+                VolumeSpec.Builder.newInstance()
+                    .setVolumeDriver(volumeDriverName)
+                    .setVolumeName(volumeToBeCreated)
+                    .setVolumeOperation(VolumeSpec.CREATE).build())
+            .addMountVolumeSpec(
+                MountVolumeSpec.Builder.newInstance()
+                    .setHostPath(volumeToBeCreated)
+                    .setMountPath(volumePathInContainer)
+                    .setReadOnly(true).build());
+        // describe devices to be mounted
+        for (Device device : allocatedDevices) {
+          builder.addMountDeviceSpec(
+              MountDeviceSpec.Builder.newInstance()
+                  .setDevicePathInHost(device.getDevPath())
+                  .setDevicePathInContainer(device.getDevPath())
+                  .setDevicePermission(MountDeviceSpec.RW).build());
+        }
+      } else if ("v2".equals(version)) {
+        // Comma-join the minor numbers. The separator is written before
+        // every element but the first, so an empty device set yields ""
+        // instead of failing while trimming a trailing comma.
+        StringBuilder minorNumbers = new StringBuilder();
+        for (Device device : allocatedDevices) {
+          if (minorNumbers.length() > 0) {
+            minorNumbers.append(',');
+          }
+          minorNumbers.append(device.getMinorNumber());
+        }
+        // setting the runtime and the environment variable is enough for
+        // a plugin like Nvidia Docker v2
+        builder.addEnv("NVIDIA_VISIBLE_DEVICES", minorNumbers.toString())
+            .setContainerRuntime("nvidia");
+      }
+      return builder.build();
+    }
+
     @Override
     public void onDevicesReleased(Set<Device> releasedDevices) {
-
+      // Intentionally empty: this test plugin does no cleanup on release.
     }
 
     @Override

+ 108 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/nvidia/com/TestNvidiaGpuPlugin.java

@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.nvidia.com;
+
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia.NvidiaGPUPluginForRuntimeV2;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Set;
+import java.util.TreeSet;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Test cases for the Nvidia GPU device plugin
+ * ({@link NvidiaGPUPluginForRuntimeV2}).
+ */
+public class TestNvidiaGpuPlugin {
+
+  /**
+   * Builds the healthy GPU device with the given index. The index is used
+   * as both the device id and the minor number, the major number is 195
+   * and the device path is {@code /dev/nvidia<index>}.
+   *
+   * @param index device id and minor number
+   * @param busId bus id of the device
+   * @return the built device
+   */
+  private static Device newGpu(int index, String busId) {
+    return Device.Builder.newInstance()
+        .setId(index).setHealthy(true)
+        .setBusID(busId)
+        .setDevPath("/dev/nvidia" + index)
+        .setMajorNumber(195)
+        .setMinorNumber(index).build();
+  }
+
+  /**
+   * Feeds mocked command output to the plugin and checks that
+   * {@code getDevices()} returns the expected set of devices.
+   */
+  @Test
+  public void testGetNvidiaDevices() throws Exception {
+    NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
+        mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
+    String deviceInfoShellOutput =
+        "0, 00000000:04:00.0\n" +
+        "1, 00000000:82:00.0";
+    // the expected devices below carry major number 195, i.e. 0xc3
+    String majorMinorNumber0 = "c3:0";
+    String majorMinorNumber1 = "c3:1";
+    when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
+    when(mockShell.getMajorMinorInfo("nvidia0"))
+        .thenReturn(majorMinorNumber0);
+    when(mockShell.getMajorMinorInfo("nvidia1"))
+        .thenReturn(majorMinorNumber1);
+    NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
+    plugin.setShellExecutor(mockShell);
+    plugin.setPathOfGpuBinary("/fake/nvidia-smi");
+
+    Set<Device> expectedDevices = new TreeSet<>();
+    expectedDevices.add(newGpu(0, "00000000:04:00.0"));
+    expectedDevices.add(newGpu(1, "00000000:82:00.0"));
+    Set<Device> devices = plugin.getDevices();
+    Assert.assertEquals(expectedDevices, devices);
+  }
+
+  /**
+   * Checks the spec returned by {@code onDevicesAllocated}: null for the
+   * default runtime, and the "nvidia" runtime plus a comma-separated
+   * NVIDIA_VISIBLE_DEVICES value for the Docker runtime.
+   */
+  @Test
+  public void testOnDeviceAllocated() throws Exception {
+    NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
+    Set<Device> allocatedDevices = new TreeSet<>();
+
+    // the default runtime gets no runtime spec
+    DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
+        YarnRuntimeType.RUNTIME_DEFAULT);
+    Assert.assertNull(spec);
+
+    // allocate one device
+    allocatedDevices.add(newGpu(0, "00000000:04:00.0"));
+    spec = plugin.onDevicesAllocated(allocatedDevices,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    Assert.assertEquals("nvidia", spec.getContainerRuntime());
+    Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
+
+    // two devices allocated; the second one has its own id (1) rather than
+    // reusing id 0
+    allocatedDevices.add(newGpu(1, "00000000:82:00.0"));
+    spec = plugin.onDevicesAllocated(allocatedDevices,
+        YarnRuntimeType.RUNTIME_DOCKER);
+    Assert.assertEquals("nvidia", spec.getContainerRuntime());
+    Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
+  }
+}