Browse Source

YARN-644: Basic null check is not performed on passed in arguments before using them in ContainerManagerImpl.startContainer

(cherry picked from commit bcf2890502fbd11dd394048fe30d67c92aeec4fa)
Robert (Bobby) Evans 10 năm trước cách đây
mục cha
commit
28e0593b96

+ 3 - 0
hadoop-yarn-project/CHANGES.txt

@@ -65,6 +65,9 @@ Release 2.8.0 - UNRELEASED
 
   IMPROVEMENTS
 
+    YARN-644. Basic null check is not performed on passed in arguments before
+    using them in ContainerManagerImpl.startContainer (Varun Saxena via bobby)
+
     YARN-1880. Cleanup TestApplicationClientProtocolOnHA
     (ozawa via harsh)
 

+ 26 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/ContainerManagerImpl.java

@@ -151,6 +151,10 @@ public class ContainerManagerImpl extends CompositeService implements
 
   private static final Log LOG = LogFactory.getLog(ContainerManagerImpl.class);
 
+  static final String INVALID_NMTOKEN_MSG = "Invalid NMToken";
+  static final String INVALID_CONTAINERTOKEN_MSG =
+      "Invalid ContainerToken";
+
   final Context context;
   private final ContainersMonitor containersMonitor;
   private Server server;
@@ -641,6 +645,9 @@ public class ContainerManagerImpl extends CompositeService implements
 
   protected void authorizeUser(UserGroupInformation remoteUgi,
       NMTokenIdentifier nmTokenIdentifier) throws YarnException {
+    if (nmTokenIdentifier == null) {
+      throw RPCUtil.getRemoteException(INVALID_NMTOKEN_MSG);
+    }
     if (!remoteUgi.getUserName().equals(
       nmTokenIdentifier.getApplicationAttemptId().toString())) {
       throw RPCUtil.getRemoteException("Expected applicationAttemptId: "
@@ -658,7 +665,12 @@ public class ContainerManagerImpl extends CompositeService implements
   @VisibleForTesting
   protected void authorizeStartRequest(NMTokenIdentifier nmTokenIdentifier,
       ContainerTokenIdentifier containerTokenIdentifier) throws YarnException {
-
+    if (nmTokenIdentifier == null) {
+      throw RPCUtil.getRemoteException(INVALID_NMTOKEN_MSG);
+    }
+    if (containerTokenIdentifier == null) {
+      throw RPCUtil.getRemoteException(INVALID_CONTAINERTOKEN_MSG);
+    }
     ContainerId containerId = containerTokenIdentifier.getContainerID();
     String containerIDStr = containerId.toString();
     boolean unauthorized = false;
@@ -717,6 +729,10 @@ public class ContainerManagerImpl extends CompositeService implements
     for (StartContainerRequest request : requests.getStartContainerRequests()) {
       ContainerId containerId = null;
       try {
+        if (request.getContainerToken() == null ||
+            request.getContainerToken().getIdentifier() == null) {
+          throw new IOException(INVALID_CONTAINERTOKEN_MSG);
+        }
         ContainerTokenIdentifier containerTokenIdentifier =
             BuilderUtils.newContainerTokenIdentifier(request.getContainerToken());
         verifyAndGetContainerTokenIdentifier(request.getContainerToken(),
@@ -946,6 +962,9 @@ public class ContainerManagerImpl extends CompositeService implements
         new HashMap<ContainerId, SerializedException>();
     UserGroupInformation remoteUgi = getRemoteUgi();
     NMTokenIdentifier identifier = selectNMTokenIdentifier(remoteUgi);
+    if (identifier == null) {
+      throw RPCUtil.getRemoteException(INVALID_NMTOKEN_MSG);
+    }
     for (ContainerId id : requests.getContainerIds()) {
       try {
         stopContainerInternal(identifier, id);
@@ -1001,6 +1020,9 @@ public class ContainerManagerImpl extends CompositeService implements
         new HashMap<ContainerId, SerializedException>();
     UserGroupInformation remoteUgi = getRemoteUgi();
     NMTokenIdentifier identifier = selectNMTokenIdentifier(remoteUgi);
+    if (identifier == null) {
+      throw RPCUtil.getRemoteException(INVALID_NMTOKEN_MSG);
+    }
     for (ContainerId id : request.getContainerIds()) {
       try {
         ContainerStatus status = getContainerStatusInternal(id, identifier);
@@ -1041,6 +1063,9 @@ public class ContainerManagerImpl extends CompositeService implements
   protected void authorizeGetAndStopContainerRequest(ContainerId containerId,
       Container container, boolean stopRequest, NMTokenIdentifier identifier)
       throws YarnException {
+    if (identifier == null) {
+      throw RPCUtil.getRemoteException(INVALID_NMTOKEN_MSG);
+    }
     /*
      * For get/stop container status; we need to verify that 1) User (NMToken)
      * application attempt only has started container. 2) Requested containerId

+ 6 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/BaseContainerManagerTest.java

@@ -230,6 +230,12 @@ public abstract class BaseContainerManagerTest {
             ByteBuffer.wrap("AuxServiceMetaData2".getBytes()));
         return serviceData;
       }
+
+      @Override
+      protected NMTokenIdentifier selectNMTokenIdentifier(
+          UserGroupInformation remoteUgi) {
+        return new NMTokenIdentifier();
+      }
     };
   }
 

+ 87 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManager.java

@@ -45,6 +45,9 @@ import org.apache.hadoop.yarn.api.protocolrecords.StartContainersRequest;
 import org.apache.hadoop.yarn.api.protocolrecords.StartContainersResponse;
 import org.apache.hadoop.yarn.api.protocolrecords.StopContainersRequest;
 import org.apache.hadoop.yarn.api.protocolrecords.StopContainersResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.impl.pb.GetContainerStatusesRequestPBImpl;
+import org.apache.hadoop.yarn.api.protocolrecords.impl.pb.StartContainersRequestPBImpl;
+import org.apache.hadoop.yarn.api.protocolrecords.impl.pb.StopContainersRequestPBImpl;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
@@ -83,6 +86,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.Mockito;
 
 public class TestContainerManager extends BaseContainerManagerTest {
 
@@ -792,6 +796,89 @@ public class TestContainerManager extends BaseContainerManagerTest {
         .contains("The auxService:" + serviceName + " does not exist"));
   }
 
+  /* Test added to verify fix in YARN-644 */
+  @Test
+  public void testNullTokens() throws Exception {
+    ContainerManagerImpl cMgrImpl =
+        new ContainerManagerImpl(context, exec, delSrvc, nodeStatusUpdater,
+        metrics, new ApplicationACLsManager(conf), dirsHandler);
+    String strExceptionMsg = "";
+    try {
+      cMgrImpl.authorizeStartRequest(null, new ContainerTokenIdentifier());
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_NMTOKEN_MSG);
+
+    strExceptionMsg = "";
+    try {
+      cMgrImpl.authorizeStartRequest(new NMTokenIdentifier(), null);
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_CONTAINERTOKEN_MSG);
+
+    strExceptionMsg = "";
+    try {
+      cMgrImpl.authorizeGetAndStopContainerRequest(null, null, true, null);
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_NMTOKEN_MSG);
+
+    strExceptionMsg = "";
+    try {
+      cMgrImpl.authorizeUser(null, null);
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_NMTOKEN_MSG);
+
+    ContainerManagerImpl spyContainerMgr = Mockito.spy(cMgrImpl);
+    UserGroupInformation ugInfo = UserGroupInformation.createRemoteUser("a");
+    Mockito.when(spyContainerMgr.getRemoteUgi()).thenReturn(ugInfo);
+    Mockito.when(spyContainerMgr.
+        selectNMTokenIdentifier(ugInfo)).thenReturn(null);
+
+    strExceptionMsg = "";
+    try {
+      spyContainerMgr.stopContainers(new StopContainersRequestPBImpl());
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_NMTOKEN_MSG);
+
+    strExceptionMsg = "";
+    try {
+      spyContainerMgr.getContainerStatuses(
+          new GetContainerStatusesRequestPBImpl());
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_NMTOKEN_MSG);
+
+    Mockito.doNothing().when(spyContainerMgr).authorizeUser(ugInfo, null);
+    List<StartContainerRequest> reqList
+        = new ArrayList<StartContainerRequest>();
+    reqList.add(StartContainerRequest.newInstance(null, null));
+    StartContainersRequest reqs = new StartContainersRequestPBImpl();
+    reqs.setStartContainerRequests(reqList);
+    strExceptionMsg = "";
+    try {
+      spyContainerMgr.startContainers(reqs);
+    } catch(YarnException ye) {
+      strExceptionMsg = ye.getCause().getMessage();
+    }
+    Assert.assertEquals(strExceptionMsg,
+        ContainerManagerImpl.INVALID_CONTAINERTOKEN_MSG);
+  }
+
   public static Token createContainerToken(ContainerId cId, long rmIdentifier,
       NodeId nodeId, String user,
       NMContainerTokenSecretManager containerTokenSecretManager)