浏览代码

AMBARI-9334 - Ambari StageDAO.findByCommandStatuses causes Postgress HIGH CPU (jonathanhurley)

Jonathan Hurley 10 年之前
父节点
当前提交
3beda060b3

+ 22 - 13
ambari-server/src/main/java/org/apache/ambari/server/actionmanager/ActionDBAccessor.java

@@ -17,15 +17,15 @@
  */
  */
 package org.apache.ambari.server.actionmanager;
 package org.apache.ambari.server.actionmanager;
 
 
-import com.google.inject.persist.Transactional;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.agent.CommandReport;
 import org.apache.ambari.server.agent.CommandReport;
 import org.apache.ambari.server.agent.ExecutionCommand;
 import org.apache.ambari.server.agent.ExecutionCommand;
 
 
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import com.google.inject.persist.Transactional;
 
 
 public interface ActionDBAccessor {
 public interface ActionDBAccessor {
 
 
@@ -57,11 +57,25 @@ public interface ActionDBAccessor {
   public void timeoutHostRole(String host, long requestId, long stageId, String role);
   public void timeoutHostRole(String host, long requestId, long stageId, String role);
 
 
   /**
   /**
-   * Returns all the pending stages, including queued and not-queued.
-   * A stage is considered in progress if it is in progress for any host.
+   * Returns all the pending stages, including queued and not-queued. A stage is
+   * considered in progress if it is in progress for any host.
+   * <p/>
+   * The results will be sorted by request ID and then stage ID making this call
+   * expensive in some scenarios. Use {@link #getCommandsInProgressCount()} in
+   * order to determine if there are stages that are in progress before getting
+   * the stages from this method.
+   *
+   * @see HostRoleStatus#IN_PROGRESS_STATUSES
    */
    */
   public List<Stage> getStagesInProgress();
   public List<Stage> getStagesInProgress();
 
 
+  /**
+   * Gets the number of commands in progress.
+   *
+   * @return the number of commands in progress.
+   */
+  public int getCommandsInProgressCount();
+
   /**
   /**
    * Persists all tasks for a given request
    * Persists all tasks for a given request
    * @param request request object
    * @param request request object
@@ -152,17 +166,12 @@ public interface ActionDBAccessor {
    * Get a List of host role commands where the host, role and status are as specified
    * Get a List of host role commands where the host, role and status are as specified
    */
    */
   public List<HostRoleCommand> getTasksByHostRoleAndStatus(String hostname, String role, HostRoleStatus status);
   public List<HostRoleCommand> getTasksByHostRoleAndStatus(String hostname, String role, HostRoleStatus status);
-  
+
   /**
   /**
    * Get a List of host role commands where the role and status are as specified
    * Get a List of host role commands where the role and status are as specified
    */
    */
   public List<HostRoleCommand> getTasksByRoleAndStatus(String role, HostRoleStatus status);
   public List<HostRoleCommand> getTasksByRoleAndStatus(String role, HostRoleStatus status);
 
 
-  /**
-   * Get all stages that contain tasks with specified host role statuses
-   */
-  public List<Stage> getStagesByHostRoleStatus(Set<HostRoleStatus> statuses);
-
   /**
   /**
    * Gets the host role command corresponding to the task id
    * Gets the host role command corresponding to the task id
    */
    */

+ 52 - 35
ambari-server/src/main/java/org/apache/ambari/server/actionmanager/ActionDBAccessorImpl.java

@@ -17,12 +17,18 @@
  */
  */
 package org.apache.ambari.server.actionmanager;
 package org.apache.ambari.server.actionmanager;
 
 
-import com.google.common.cache.Cache;
-import com.google.common.cache.CacheBuilder;
-import com.google.inject.Inject;
-import com.google.inject.Singleton;
-import com.google.inject.name.Named;
-import com.google.inject.persist.Transactional;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.agent.CommandReport;
 import org.apache.ambari.server.agent.CommandReport;
 import org.apache.ambari.server.agent.ExecutionCommand;
 import org.apache.ambari.server.agent.ExecutionCommand;
@@ -48,49 +54,55 @@ import org.apache.ambari.server.utils.StageUtils;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.TimeUnit;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.inject.Inject;
+import com.google.inject.Singleton;
+import com.google.inject.name.Named;
+import com.google.inject.persist.Transactional;
 
 
 @Singleton
 @Singleton
 public class ActionDBAccessorImpl implements ActionDBAccessor {
 public class ActionDBAccessorImpl implements ActionDBAccessor {
   private static final Logger LOG = LoggerFactory.getLogger(ActionDBAccessorImpl.class);
   private static final Logger LOG = LoggerFactory.getLogger(ActionDBAccessorImpl.class);
+
   private long requestId;
   private long requestId;
+
   @Inject
   @Inject
   ClusterDAO clusterDAO;
   ClusterDAO clusterDAO;
+
   @Inject
   @Inject
   HostDAO hostDAO;
   HostDAO hostDAO;
+
   @Inject
   @Inject
   RequestDAO requestDAO;
   RequestDAO requestDAO;
+
   @Inject
   @Inject
   StageDAO stageDAO;
   StageDAO stageDAO;
+
   @Inject
   @Inject
   HostRoleCommandDAO hostRoleCommandDAO;
   HostRoleCommandDAO hostRoleCommandDAO;
+
   @Inject
   @Inject
   ExecutionCommandDAO executionCommandDAO;
   ExecutionCommandDAO executionCommandDAO;
+
   @Inject
   @Inject
   RoleSuccessCriteriaDAO roleSuccessCriteriaDAO;
   RoleSuccessCriteriaDAO roleSuccessCriteriaDAO;
+
   @Inject
   @Inject
   StageFactory stageFactory;
   StageFactory stageFactory;
+
   @Inject
   @Inject
   RequestFactory requestFactory;
   RequestFactory requestFactory;
+
   @Inject
   @Inject
   HostRoleCommandFactory hostRoleCommandFactory;
   HostRoleCommandFactory hostRoleCommandFactory;
+
   @Inject
   @Inject
   Clusters clusters;
   Clusters clusters;
+
   @Inject
   @Inject
   RequestScheduleDAO requestScheduleDAO;
   RequestScheduleDAO requestScheduleDAO;
 
 
-
-
   private Cache<Long, HostRoleCommand> hostRoleCommandCache;
   private Cache<Long, HostRoleCommand> hostRoleCommandCache;
   private long cacheLimit; //may be exceeded to store tasks from one request
   private long cacheLimit; //may be exceeded to store tasks from one request
 
 
@@ -186,22 +198,35 @@ public class ActionDBAccessorImpl implements ActionDBAccessor {
     endRequestIfCompleted(requestId);
     endRequestIfCompleted(requestId);
   }
   }
 
 
-  /* (non-Javadoc)
-   * @see org.apache.ambari.server.actionmanager.ActionDBAccessor#getPendingStages()
+  /**
+   * {@inheritDoc}
    */
    */
   @Override
   @Override
   public List<Stage> getStagesInProgress() {
   public List<Stage> getStagesInProgress() {
     List<Stage> stages = new ArrayList<Stage>();
     List<Stage> stages = new ArrayList<Stage>();
-    List<HostRoleStatus> statuses =
-        Arrays.asList(HostRoleStatus.QUEUED, HostRoleStatus.IN_PROGRESS,
-          HostRoleStatus.PENDING, HostRoleStatus.HOLDING,
-          HostRoleStatus.HOLDING_FAILED, HostRoleStatus.HOLDING_TIMEDOUT);
-    for (StageEntity stageEntity : stageDAO.findByCommandStatuses(statuses)) {
+
+    List<StageEntity> stageEntities = stageDAO.findByCommandStatuses(HostRoleStatus.IN_PROGRESS_STATUSES);
+
+    for (StageEntity stageEntity : stageEntities) {
       stages.add(stageFactory.createExisting(stageEntity));
       stages.add(stageFactory.createExisting(stageEntity));
     }
     }
+
     return stages;
     return stages;
   }
   }
 
 
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public int getCommandsInProgressCount() {
+    Number count = hostRoleCommandDAO.getCountByStatus(HostRoleStatus.IN_PROGRESS_STATUSES);
+    if (null == count) {
+      return 0;
+    }
+
+    return count.intValue();
+  }
+
   @Override
   @Override
   @Transactional
   @Transactional
   public void persistActions(Request request) throws AmbariException {
   public void persistActions(Request request) throws AmbariException {
@@ -562,16 +587,8 @@ public class ActionDBAccessorImpl implements ActionDBAccessor {
   public List<HostRoleCommand> getTasksByRoleAndStatus(String role, HostRoleStatus status) {
   public List<HostRoleCommand> getTasksByRoleAndStatus(String role, HostRoleStatus status) {
     return getTasks(hostRoleCommandDAO.findTaskIdsByRoleAndStatus(role, status));
     return getTasks(hostRoleCommandDAO.findTaskIdsByRoleAndStatus(role, status));
   }
   }
-  
-  @Override
-  public List<Stage> getStagesByHostRoleStatus(Set<HostRoleStatus> statuses) {
-    List<Stage> stages = new ArrayList<Stage>();
-    for (StageEntity stageEntity : stageDAO.findByCommandStatuses(statuses)) {
-      stages.add(stageFactory.createExisting(stageEntity));
-    }
-    return stages;
-  }
 
 
+  @Override
   public HostRoleCommand getTask(long taskId) {
   public HostRoleCommand getTask(long taskId) {
     HostRoleCommandEntity commandEntity = hostRoleCommandDAO.findByPK((int) taskId);
     HostRoleCommandEntity commandEntity = hostRoleCommandDAO.findByPK((int) taskId);
     if (commandEntity == null) {
     if (commandEntity == null) {

+ 12 - 17
ambari-server/src/main/java/org/apache/ambari/server/actionmanager/ActionManager.java

@@ -17,11 +17,13 @@
  */
  */
 package org.apache.ambari.server.actionmanager;
 package org.apache.ambari.server.actionmanager;
 
 
-import com.google.inject.Inject;
-import com.google.inject.Singleton;
-import com.google.inject.assistedinject.Assisted;
-import com.google.inject.name.Named;
-import com.google.inject.persist.UnitOfWork;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
+
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.agent.ActionQueue;
 import org.apache.ambari.server.agent.ActionQueue;
 import org.apache.ambari.server.agent.CommandReport;
 import org.apache.ambari.server.agent.CommandReport;
@@ -34,13 +36,10 @@ import org.apache.ambari.server.utils.StageUtils;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicLong;
+import com.google.inject.Inject;
+import com.google.inject.Singleton;
+import com.google.inject.name.Named;
+import com.google.inject.persist.UnitOfWork;
 
 
 
 
 /**
 /**
@@ -63,7 +62,7 @@ public class ActionManager {
                        UnitOfWork unitOfWork,
                        UnitOfWork unitOfWork,
                        RequestFactory requestFactory, Configuration configuration,
                        RequestFactory requestFactory, Configuration configuration,
                        AmbariEventPublisher ambariEventPublisher) {
                        AmbariEventPublisher ambariEventPublisher) {
-    this.actionQueue = aq;
+    actionQueue = aq;
     this.db = db;
     this.db = db;
     scheduler = new ActionScheduler(schedulerSleepTime, actionTimeout, db,
     scheduler = new ActionScheduler(schedulerSleepTime, actionTimeout, db,
         actionQueue, fsm, 2, hostsMap, unitOfWork, ambariEventPublisher, configuration);
         actionQueue, fsm, 2, hostsMap, unitOfWork, ambariEventPublisher, configuration);
@@ -205,10 +204,6 @@ public class ActionManager {
     return db.getTasks(taskIds);
     return db.getTasks(taskIds);
   }
   }
 
 
-  public List<Stage> getRequestsByHostRoleStatus(Set<HostRoleStatus> statuses) {
-    return db.getStagesByHostRoleStatus(statuses);
-  }
-
   /**
   /**
    * Get first or last maxResults requests that are in the specified status
    * Get first or last maxResults requests that are in the specified status
    *
    *

+ 27 - 11
ambari-server/src/main/java/org/apache/ambari/server/actionmanager/ActionScheduler.java

@@ -129,23 +129,23 @@ class ActionScheduler implements Runnable {
                          int maxAttempts, HostsMap hostsMap,
                          int maxAttempts, HostsMap hostsMap,
                          UnitOfWork unitOfWork, AmbariEventPublisher ambariEventPublisher,
                          UnitOfWork unitOfWork, AmbariEventPublisher ambariEventPublisher,
                          Configuration configuration) {
                          Configuration configuration) {
-    this.sleepTime = sleepTimeMilliSec;
+    sleepTime = sleepTimeMilliSec;
     this.hostsMap = hostsMap;
     this.hostsMap = hostsMap;
-    this.actionTimeout = actionTimeoutMilliSec;
+    actionTimeout = actionTimeoutMilliSec;
     this.db = db;
     this.db = db;
     this.actionQueue = actionQueue;
     this.actionQueue = actionQueue;
     this.fsmObject = fsmObject;
     this.fsmObject = fsmObject;
     this.ambariEventPublisher = ambariEventPublisher;
     this.ambariEventPublisher = ambariEventPublisher;
     this.maxAttempts = (short) maxAttempts;
     this.maxAttempts = (short) maxAttempts;
-    this.serverActionExecutor = new ServerActionExecutor(db, sleepTimeMilliSec);
+    serverActionExecutor = new ServerActionExecutor(db, sleepTimeMilliSec);
     this.unitOfWork = unitOfWork;
     this.unitOfWork = unitOfWork;
-    this.clusterHostInfoCache = CacheBuilder.newBuilder().
+    clusterHostInfoCache = CacheBuilder.newBuilder().
         expireAfterAccess(5, TimeUnit.MINUTES).
         expireAfterAccess(5, TimeUnit.MINUTES).
         build();
         build();
-    this.commandParamsStageCache = CacheBuilder.newBuilder().
+    commandParamsStageCache = CacheBuilder.newBuilder().
       expireAfterAccess(5, TimeUnit.MINUTES).
       expireAfterAccess(5, TimeUnit.MINUTES).
       build();
       build();
-    this.hostParamsStageCache = CacheBuilder.newBuilder().
+    hostParamsStageCache = CacheBuilder.newBuilder().
       expireAfterAccess(5, TimeUnit.MINUTES).
       expireAfterAccess(5, TimeUnit.MINUTES).
       build();
       build();
     this.configuration = configuration;
     this.configuration = configuration;
@@ -212,19 +212,34 @@ class ActionScheduler implements Runnable {
       // The first thing to do is to abort requests that are cancelled
       // The first thing to do is to abort requests that are cancelled
       processCancelledRequestsList();
       processCancelledRequestsList();
 
 
+      // !!! getting the stages in progress could be a very expensive call due
+      // to the join being used; there's no need to make it if there are
+      // no commands in progress
+      if (db.getCommandsInProgressCount() == 0) {
+        // Nothing to do
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("There are no stages currently in progress.");
+        }
+
+        return;
+      }
+
       Set<Long> runningRequestIds = new HashSet<Long>();
       Set<Long> runningRequestIds = new HashSet<Long>();
       List<Stage> stages = db.getStagesInProgress();
       List<Stage> stages = db.getStagesInProgress();
       if (LOG.isDebugEnabled()) {
       if (LOG.isDebugEnabled()) {
         LOG.debug("Scheduler wakes up");
         LOG.debug("Scheduler wakes up");
         LOG.debug("Processing {} in progress stages ", stages.size());
         LOG.debug("Processing {} in progress stages ", stages.size());
       }
       }
+
       if (stages.isEmpty()) {
       if (stages.isEmpty()) {
-        //Nothing to do
+        // Nothing to do
         if (LOG.isDebugEnabled()) {
         if (LOG.isDebugEnabled()) {
-          LOG.debug("No stage in progress..nothing to do");
+          LOG.debug("There are no stages currently in progress.");
         }
         }
+
         return;
         return;
       }
       }
+
       int i_stage = 0;
       int i_stage = 0;
 
 
       stages = filterParallelPerHostStages(stages);
       stages = filterParallelPerHostStages(stages);
@@ -590,7 +605,7 @@ class ActionScheduler implements Runnable {
           LOG.trace("===>commandsToSchedule(first_time)=" + commandsToSchedule.size());
           LOG.trace("===>commandsToSchedule(first_time)=" + commandsToSchedule.size());
         }
         }
 
 
-        this.updateRoleStats(status, roleStats.get(roleStr));
+        updateRoleStats(status, roleStats.get(roleStr));
       }
       }
     }
     }
     LOG.debug("Collected {} commands to schedule in this wakeup.", commandsToSchedule.size());
     LOG.debug("Collected {} commands to schedule in this wakeup.", commandsToSchedule.size());
@@ -912,7 +927,7 @@ class ActionScheduler implements Runnable {
 
 
 
 
   public void setTaskTimeoutAdjustment(boolean val) {
   public void setTaskTimeoutAdjustment(boolean val) {
-    this.taskTimeoutAdjustment = val;
+    taskTimeoutAdjustment = val;
   }
   }
 
 
   ServerActionExecutor getServerActionExecutor() {
   ServerActionExecutor getServerActionExecutor() {
@@ -932,7 +947,7 @@ class ActionScheduler implements Runnable {
     final float successFactor;
     final float successFactor;
 
 
     RoleStats(int total, float successFactor) {
     RoleStats(int total, float successFactor) {
-      this.totalHosts = total;
+      totalHosts = total;
       this.successFactor = successFactor;
       this.successFactor = successFactor;
     }
     }
 
 
@@ -956,6 +971,7 @@ class ActionScheduler implements Runnable {
       return !(isRoleInProgress() || isSuccessFactorMet());
       return !(isRoleInProgress() || isSuccessFactorMet());
     }
     }
 
 
+    @Override
     public String toString() {
     public String toString() {
       StringBuilder builder = new StringBuilder();
       StringBuilder builder = new StringBuilder();
       builder.append("numQueued=").append(numQueued);
       builder.append("numQueued=").append(numQueued);

+ 9 - 0
ambari-server/src/main/java/org/apache/ambari/server/actionmanager/HostRoleStatus.java

@@ -19,6 +19,7 @@ package org.apache.ambari.server.actionmanager;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.List;
 
 
 public enum HostRoleStatus {
 public enum HostRoleStatus {
@@ -67,6 +68,14 @@ public enum HostRoleStatus {
   private static List<HostRoleStatus> FAILED_STATES = Arrays.asList(FAILED, TIMEDOUT, ABORTED);
   private static List<HostRoleStatus> FAILED_STATES = Arrays.asList(FAILED, TIMEDOUT, ABORTED);
   private static List<HostRoleStatus> HOLDING_STATES = Arrays.asList(HOLDING, HOLDING_FAILED, HOLDING_TIMEDOUT);
   private static List<HostRoleStatus> HOLDING_STATES = Arrays.asList(HOLDING, HOLDING_FAILED, HOLDING_TIMEDOUT);
 
 
+  /**
+   * The {@link HostRoleStatus}s that represent any commands which are
+   * considered to be "In Progress".
+   */
+  public static final EnumSet<HostRoleStatus> IN_PROGRESS_STATUSES = EnumSet.of(
+      HostRoleStatus.QUEUED, HostRoleStatus.IN_PROGRESS,
+      HostRoleStatus.PENDING, HostRoleStatus.HOLDING,
+      HostRoleStatus.HOLDING_FAILED, HostRoleStatus.HOLDING_TIMEDOUT);
 
 
   /**
   /**
    * Indicates whether or not it is a valid failure state.
    * Indicates whether or not it is a valid failure state.

+ 54 - 16
ambari-server/src/main/java/org/apache/ambari/server/orm/dao/HostRoleCommandDAO.java

@@ -18,26 +18,30 @@
 
 
 package org.apache.ambari.server.orm.dao;
 package org.apache.ambari.server.orm.dao;
 
 
-import com.google.common.collect.Lists;
-import com.google.inject.Inject;
-import com.google.inject.Provider;
-import com.google.inject.Singleton;
-import com.google.inject.persist.Transactional;
-import org.apache.ambari.server.actionmanager.HostRoleStatus;
-import org.apache.ambari.server.orm.RequiresSession;
-import org.apache.ambari.server.orm.entities.HostEntity;
-import org.apache.ambari.server.orm.entities.HostRoleCommandEntity;
-import org.apache.ambari.server.orm.entities.StageEntity;
-import javax.persistence.EntityManager;
-import javax.persistence.TypedQuery;
+import static org.apache.ambari.server.orm.DBAccessor.DbType.ORACLE;
+import static org.apache.ambari.server.orm.dao.DaoUtils.ORACLE_LIST_LIMIT;
+
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
-import static org.apache.ambari.server.orm.DBAccessor.DbType.ORACLE;
-import static org.apache.ambari.server.orm.dao.DaoUtils.ORACLE_LIST_LIMIT;
+
+import javax.persistence.EntityManager;
+import javax.persistence.TypedQuery;
+
+import org.apache.ambari.server.actionmanager.HostRoleStatus;
+import org.apache.ambari.server.orm.RequiresSession;
+import org.apache.ambari.server.orm.entities.HostEntity;
+import org.apache.ambari.server.orm.entities.HostRoleCommandEntity;
+import org.apache.ambari.server.orm.entities.StageEntity;
+
+import com.google.common.collect.Lists;
+import com.google.inject.Inject;
+import com.google.inject.Provider;
+import com.google.inject.Singleton;
+import com.google.inject.persist.Transactional;
 
 
 @Singleton
 @Singleton
 public class HostRoleCommandDAO {
 public class HostRoleCommandDAO {
@@ -145,8 +149,8 @@ public class HostRoleCommandDAO {
 
 
     return daoUtils.selectList(query, role, status);
     return daoUtils.selectList(query, role, status);
   }
   }
-  
-  
+
+
   @RequiresSession
   @RequiresSession
   public List<HostRoleCommandEntity> findSortedCommandsByStageAndHost(StageEntity stageEntity, HostEntity hostEntity) {
   public List<HostRoleCommandEntity> findSortedCommandsByStageAndHost(StageEntity stageEntity, HostEntity hostEntity) {
     TypedQuery<HostRoleCommandEntity> query = entityManagerProvider.get().createQuery("SELECT hostRoleCommand " +
     TypedQuery<HostRoleCommandEntity> query = entityManagerProvider.get().createQuery("SELECT hostRoleCommand " +
@@ -215,6 +219,40 @@ public class HostRoleCommandDAO {
     return daoUtils.selectList(query, requestId);
     return daoUtils.selectList(query, requestId);
   }
   }
 
 
+  /**
+   * Gets the commands in a particular status.
+   *
+   * @param statuses
+   *          the statuses to include (not {@code null}).
+   * @return the commands in the given set of statuses.
+   */
+  @RequiresSession
+  public List<HostRoleCommandEntity> findByStatus(
+      Collection<HostRoleStatus> statuses) {
+    TypedQuery<HostRoleCommandEntity> query = entityManagerProvider.get().createNamedQuery(
+        "HostRoleCommandEntity.findByCommandStatuses",
+        HostRoleCommandEntity.class);
+
+    query.setParameter("statuses", statuses);
+    return daoUtils.selectList(query);
+  }
+
+  /**
+   * Gets the number of commands in a particular status.
+   *
+   * @param statuses
+   *          the statuses to include (not {@code null}).
+   * @return the count of commands in the given set of statuses.
+   */
+  @RequiresSession
+  public Number getCountByStatus(Collection<HostRoleStatus> statuses) {
+    TypedQuery<Number> query = entityManagerProvider.get().createNamedQuery(
+        "HostRoleCommandEntity.findCountByCommandStatuses", Number.class);
+
+    query.setParameter("statuses", statuses);
+    return daoUtils.selectSingle(query);
+  }
+
   @RequiresSession
   @RequiresSession
   public List<HostRoleCommandEntity> findAll() {
   public List<HostRoleCommandEntity> findAll() {
     return daoUtils.selectAll(entityManagerProvider.get(), HostRoleCommandEntity.class);
     return daoUtils.selectAll(entityManagerProvider.get(), HostRoleCommandEntity.class);

+ 28 - 23
ambari-server/src/main/java/org/apache/ambari/server/orm/dao/StageDAO.java

@@ -18,10 +18,19 @@
 
 
 package org.apache.ambari.server.orm.dao;
 package org.apache.ambari.server.orm.dao;
 
 
-import com.google.inject.Inject;
-import com.google.inject.Provider;
-import com.google.inject.Singleton;
-import com.google.inject.persist.Transactional;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.persistence.EntityManager;
+import javax.persistence.TypedQuery;
+import javax.persistence.criteria.CriteriaQuery;
+import javax.persistence.criteria.Order;
+import javax.persistence.metamodel.SingularAttribute;
+
 import org.apache.ambari.server.actionmanager.HostRoleStatus;
 import org.apache.ambari.server.actionmanager.HostRoleStatus;
 import org.apache.ambari.server.api.query.JpaPredicateVisitor;
 import org.apache.ambari.server.api.query.JpaPredicateVisitor;
 import org.apache.ambari.server.api.query.JpaSortBuilder;
 import org.apache.ambari.server.api.query.JpaSortBuilder;
@@ -36,17 +45,10 @@ import org.apache.ambari.server.utils.StageUtils;
 import org.eclipse.persistence.config.HintValues;
 import org.eclipse.persistence.config.HintValues;
 import org.eclipse.persistence.config.QueryHints;
 import org.eclipse.persistence.config.QueryHints;
 
 
-import javax.persistence.EntityManager;
-import javax.persistence.TypedQuery;
-import javax.persistence.criteria.CriteriaQuery;
-import javax.persistence.criteria.Order;
-import javax.persistence.metamodel.SingularAttribute;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Collection;
-import java.util.Map;
-import java.util.Set;
+import com.google.inject.Inject;
+import com.google.inject.Provider;
+import com.google.inject.Singleton;
+import com.google.inject.persist.Transactional;
 
 
 @Singleton
 @Singleton
 public class StageDAO {
 public class StageDAO {
@@ -115,12 +117,13 @@ public class StageDAO {
   }
   }
 
 
   @RequiresSession
   @RequiresSession
-  public List<StageEntity> findByCommandStatuses(Collection<HostRoleStatus> statuses) {
-    TypedQuery<StageEntity> query = entityManagerProvider.get().createQuery("SELECT stage " +
-          "FROM StageEntity stage WHERE stage.stageId IN (SELECT hrce.stageId FROM " +
-          "HostRoleCommandEntity hrce WHERE stage.requestId = hrce.requestId and hrce.status IN ?1 ) " +
-          "ORDER BY stage.requestId, stage.stageId", StageEntity.class);
-    return daoUtils.selectList(query, statuses);
+  public List<StageEntity> findByCommandStatuses(
+      Collection<HostRoleStatus> statuses) {
+    TypedQuery<StageEntity> query = entityManagerProvider.get().createNamedQuery(
+        "StageEntity.findByCommandStatuses", StageEntity.class);
+
+    query.setParameter("statuses", statuses);
+    return daoUtils.selectList(query);
   }
   }
 
 
   @RequiresSession
   @RequiresSession
@@ -147,10 +150,12 @@ public class StageDAO {
       "SELECT stage.requestContext " + "FROM StageEntity stage " +
       "SELECT stage.requestContext " + "FROM StageEntity stage " +
         "WHERE stage.requestId=?1", String.class);
         "WHERE stage.requestId=?1", String.class);
     String result =  daoUtils.selectOne(query, requestId);
     String result =  daoUtils.selectOne(query, requestId);
-    if (result != null)
+    if (result != null) {
       return result;
       return result;
-    else
+    }
+    else {
       return ""; // Since it is defined as empty string in the StageEntity
       return ""; // Since it is defined as empty string in the StageEntity
+    }
   }
   }
 
 
   @Transactional
   @Transactional

+ 65 - 24
ambari-server/src/main/java/org/apache/ambari/server/orm/entities/HostRoleCommandEntity.java

@@ -36,6 +36,8 @@ import javax.persistence.JoinColumn;
 import javax.persistence.JoinColumns;
 import javax.persistence.JoinColumns;
 import javax.persistence.Lob;
 import javax.persistence.Lob;
 import javax.persistence.ManyToOne;
 import javax.persistence.ManyToOne;
+import javax.persistence.NamedQueries;
+import javax.persistence.NamedQuery;
 import javax.persistence.OneToOne;
 import javax.persistence.OneToOne;
 import javax.persistence.Table;
 import javax.persistence.Table;
 import javax.persistence.TableGenerator;
 import javax.persistence.TableGenerator;
@@ -45,15 +47,17 @@ import org.apache.ambari.server.RoleCommand;
 import org.apache.ambari.server.actionmanager.HostRoleStatus;
 import org.apache.ambari.server.actionmanager.HostRoleStatus;
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang.ArrayUtils;
 
 
-@Table(name = "host_role_command")
 @Entity
 @Entity
+@Table(name = "host_role_command")
 @TableGenerator(name = "host_role_command_id_generator",
 @TableGenerator(name = "host_role_command_id_generator",
     table = "ambari_sequences", pkColumnName = "sequence_name", valueColumnName = "sequence_value"
     table = "ambari_sequences", pkColumnName = "sequence_name", valueColumnName = "sequence_value"
     , pkColumnValue = "host_role_command_id_seq"
     , pkColumnValue = "host_role_command_id_seq"
     , initialValue = 1
     , initialValue = 1
     , allocationSize = 50
     , allocationSize = 50
 )
 )
-
+@NamedQueries({
+    @NamedQuery(name = "HostRoleCommandEntity.findCountByCommandStatuses", query = "SELECT COUNT(command.taskId) FROM HostRoleCommandEntity command WHERE command.status IN :statuses"),
+    @NamedQuery(name = "HostRoleCommandEntity.findByCommandStatuses", query = "SELECT command FROM HostRoleCommandEntity command WHERE command.status IN :statuses ORDER BY command.requestId, command.stageId") })
 public class HostRoleCommandEntity {
 public class HostRoleCommandEntity {
 
 
   private static int MAX_COMMAND_DETAIL_LENGTH = 250;
   private static int MAX_COMMAND_DETAIL_LENGTH = 250;
@@ -193,7 +197,7 @@ public class HostRoleCommandEntity {
   }
   }
 
 
   public Role getRole() {
   public Role getRole() {
-    return Role.valueOf(this.role);
+    return Role.valueOf(role);
   }
   }
 
 
   public void setRole(Role role) {
   public void setRole(Role role) {
@@ -333,34 +337,71 @@ public class HostRoleCommandEntity {
    * @param enabled  {@code true} if this task should hold for retry when an error occurs
    * @param enabled  {@code true} if this task should hold for retry when an error occurs
    */
    */
   public void setRetryAllowed(boolean enabled) {
   public void setRetryAllowed(boolean enabled) {
-    this.retryAllowed = enabled ? 1 : 0;
+    retryAllowed = enabled ? 1 : 0;
   }
   }
 
 
   @Override
   @Override
   public boolean equals(Object o) {
   public boolean equals(Object o) {
-    if (this == o) return true;
-    if (o == null || getClass() != o.getClass()) return false;
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
 
 
     HostRoleCommandEntity that = (HostRoleCommandEntity) o;
     HostRoleCommandEntity that = (HostRoleCommandEntity) o;
 
 
-    if (attemptCount != null ? !attemptCount.equals(that.attemptCount) : that.attemptCount != null) return false;
-    if (event != null ? !event.equals(that.event) : that.event != null) return false;
-    if (exitcode != null ? !exitcode.equals(that.exitcode) : that.exitcode != null) return false;
-    if (hostName != null ? !hostName.equals(that.hostName) : that.hostName != null) return false;
-    if (lastAttemptTime != null ? !lastAttemptTime.equals(that.lastAttemptTime) : that.lastAttemptTime != null)
+    if (attemptCount != null ? !attemptCount.equals(that.attemptCount) : that.attemptCount != null) {
+      return false;
+    }
+    if (event != null ? !event.equals(that.event) : that.event != null) {
+      return false;
+    }
+    if (exitcode != null ? !exitcode.equals(that.exitcode) : that.exitcode != null) {
+      return false;
+    }
+    if (hostName != null ? !hostName.equals(that.hostName) : that.hostName != null) {
+      return false;
+    }
+    if (lastAttemptTime != null ? !lastAttemptTime.equals(that.lastAttemptTime) : that.lastAttemptTime != null) {
+      return false;
+    }
+    if (requestId != null ? !requestId.equals(that.requestId) : that.requestId != null) {
+      return false;
+    }
+    if (role != null ? !role.equals(that.role) : that.role != null) {
+      return false;
+    }
+    if (stageId != null ? !stageId.equals(that.stageId) : that.stageId != null) {
       return false;
       return false;
-    if (requestId != null ? !requestId.equals(that.requestId) : that.requestId != null) return false;
-    if (role != null ? !role.equals(that.role) : that.role != null) return false;
-    if (stageId != null ? !stageId.equals(that.stageId) : that.stageId != null) return false;
-    if (startTime != null ? !startTime.equals(that.startTime) : that.startTime != null) return false;
-    if (status != null ? !status.equals(that.status) : that.status != null) return false;
-    if (stdError != null ? !Arrays.equals(stdError, that.stdError) : that.stdError != null) return false;
-    if (stdOut != null ? !Arrays.equals(stdOut, that.stdOut) : that.stdOut != null) return false;
-    if (outputLog != null ? !outputLog.equals(that.outputLog) : that.outputLog != null) return false;
-    if (errorLog != null ? !errorLog.equals(that.errorLog) : that.errorLog != null) return false;
-    if (taskId != null ? !taskId.equals(that.taskId) : that.taskId != null) return false;
-    if (structuredOut != null ? !Arrays.equals(structuredOut, that.structuredOut) : that.structuredOut != null) return false;
-    if (endTime != null ? !endTime.equals(that.endTime) : that.endTime != null) return false;
+    }
+    if (startTime != null ? !startTime.equals(that.startTime) : that.startTime != null) {
+      return false;
+    }
+    if (status != null ? !status.equals(that.status) : that.status != null) {
+      return false;
+    }
+    if (stdError != null ? !Arrays.equals(stdError, that.stdError) : that.stdError != null) {
+      return false;
+    }
+    if (stdOut != null ? !Arrays.equals(stdOut, that.stdOut) : that.stdOut != null) {
+      return false;
+    }
+    if (outputLog != null ? !outputLog.equals(that.outputLog) : that.outputLog != null) {
+      return false;
+    }
+    if (errorLog != null ? !errorLog.equals(that.errorLog) : that.errorLog != null) {
+      return false;
+    }
+    if (taskId != null ? !taskId.equals(that.taskId) : that.taskId != null) {
+      return false;
+    }
+    if (structuredOut != null ? !Arrays.equals(structuredOut, that.structuredOut) : that.structuredOut != null) {
+      return false;
+    }
+    if (endTime != null ? !endTime.equals(that.endTime) : that.endTime != null) {
+      return false;
+    }
 
 
     return true;
     return true;
   }
   }
@@ -392,7 +433,7 @@ public class HostRoleCommandEntity {
   }
   }
 
 
   public void setExecutionCommand(ExecutionCommandEntity executionCommandsByTaskId) {
   public void setExecutionCommand(ExecutionCommandEntity executionCommandsByTaskId) {
-    this.executionCommand = executionCommandsByTaskId;
+    executionCommand = executionCommandsByTaskId;
   }
   }
 
 
   public StageEntity getStage() {
   public StageEntity getStage() {

+ 57 - 18
ambari-server/src/main/java/org/apache/ambari/server/orm/entities/StageEntity.java

@@ -18,14 +18,28 @@
 
 
 package org.apache.ambari.server.orm.entities;
 package org.apache.ambari.server.orm.entities;
 
 
-import javax.persistence.*;
+import static org.apache.commons.lang.StringUtils.defaultString;
+
 import java.util.Collection;
 import java.util.Collection;
 
 
-import static org.apache.commons.lang.StringUtils.defaultString;
+import javax.persistence.Basic;
+import javax.persistence.CascadeType;
+import javax.persistence.Column;
+import javax.persistence.Entity;
+import javax.persistence.FetchType;
+import javax.persistence.Id;
+import javax.persistence.IdClass;
+import javax.persistence.JoinColumn;
+import javax.persistence.ManyToOne;
+import javax.persistence.NamedQueries;
+import javax.persistence.NamedQuery;
+import javax.persistence.OneToMany;
+import javax.persistence.Table;
 
 
-@IdClass(org.apache.ambari.server.orm.entities.StageEntityPK.class)
-@Table(name = "stage")
 @Entity
 @Entity
+@Table(name = "stage")
+@IdClass(org.apache.ambari.server.orm.entities.StageEntityPK.class)
+@NamedQueries({ @NamedQuery(name = "StageEntity.findByCommandStatuses", query = "SELECT stage from StageEntity stage WHERE EXISTS (SELECT roleCommand.stageId from HostRoleCommandEntity roleCommand WHERE roleCommand.status IN :statuses AND roleCommand.stageId = stage.stageId AND roleCommand.requestId = stage.requestId ) ORDER by stage.requestId, stage.stageId") })
 public class StageEntity {
 public class StageEntity {
 
 
   @Column(name = "cluster_id", updatable = false, nullable = false)
   @Column(name = "cluster_id", updatable = false, nullable = false)
@@ -50,11 +64,11 @@ public class StageEntity {
   @Column(name = "request_context")
   @Column(name = "request_context")
   @Basic
   @Basic
   private String requestContext = "";
   private String requestContext = "";
-  
+
   @Column(name = "cluster_host_info")
   @Column(name = "cluster_host_info")
   @Basic
   @Basic
   private byte[] clusterHostInfo;
   private byte[] clusterHostInfo;
- 
+
   @Column(name = "command_params")
   @Column(name = "command_params")
   @Basic
   @Basic
   private byte[] commandParamsStage;
   private byte[] commandParamsStage;
@@ -66,7 +80,7 @@ public class StageEntity {
   @ManyToOne
   @ManyToOne
   @JoinColumn(name = "request_id", referencedColumnName = "request_id", nullable = false)
   @JoinColumn(name = "request_id", referencedColumnName = "request_id", nullable = false)
   private RequestEntity request;
   private RequestEntity request;
-  
+
 
 
   @OneToMany(mappedBy = "stage", cascade = CascadeType.REMOVE, fetch = FetchType.LAZY)
   @OneToMany(mappedBy = "stage", cascade = CascadeType.REMOVE, fetch = FetchType.LAZY)
   private Collection<HostRoleCommandEntity> hostRoleCommands;
   private Collection<HostRoleCommandEntity> hostRoleCommands;
@@ -117,7 +131,7 @@ public class StageEntity {
   public void setClusterHostInfo(String clusterHostInfo) {
   public void setClusterHostInfo(String clusterHostInfo) {
     this.clusterHostInfo = clusterHostInfo.getBytes();
     this.clusterHostInfo = clusterHostInfo.getBytes();
   }
   }
- 
+
   public String getCommandParamsStage() {
   public String getCommandParamsStage() {
     return commandParamsStage == null ? new String() : new String(commandParamsStage);
     return commandParamsStage == null ? new String() : new String(commandParamsStage);
   }
   }
@@ -142,20 +156,45 @@ public class StageEntity {
 
 
   @Override
   @Override
   public boolean equals(Object o) {
   public boolean equals(Object o) {
-    if (this == o) return true;
-    if (o == null || getClass() != o.getClass()) return false;
+    if (this == o) {
+      return true;
+    }
+
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
 
 
     StageEntity that = (StageEntity) o;
     StageEntity that = (StageEntity) o;
 
 
-    if (clusterId != null ? !clusterId.equals(that.clusterId) : that.clusterId != null) return false;
-    if (logInfo != null ? !logInfo.equals(that.logInfo) : that.logInfo != null) return false;
-    if (requestId != null ? !requestId.equals(that.requestId) : that.requestId != null) return false;
-    if (stageId != null ? !stageId.equals(that.stageId) : that.stageId != null) return false;
-    if (clusterHostInfo != null ? !clusterHostInfo.equals(that.clusterHostInfo) : that.clusterHostInfo != null) return false;
-    if (commandParamsStage != null ? !commandParamsStage.equals(that.commandParamsStage) : that.commandParamsStage != null) return false;
-    if (hostParamsStage != null ? !hostParamsStage.equals(that.hostParamsStage) : that.hostParamsStage != null) return false;
-    return !(requestContext != null ? !requestContext.equals(that.requestContext) : that.requestContext != null);
+    if (clusterId != null ? !clusterId.equals(that.clusterId) : that.clusterId != null) {
+      return false;
+    }
+
+    if (requestId != null ? !requestId.equals(that.requestId) : that.requestId != null) {
+      return false;
+    }
+
+    if (stageId != null ? !stageId.equals(that.stageId) : that.stageId != null) {
+      return false;
+    }
+
+    if (logInfo != null ? !logInfo.equals(that.logInfo) : that.logInfo != null) {
+      return false;
+    }
 
 
+    if (clusterHostInfo != null ? !clusterHostInfo.equals(that.clusterHostInfo) : that.clusterHostInfo != null) {
+      return false;
+    }
+
+    if (commandParamsStage != null ? !commandParamsStage.equals(that.commandParamsStage) : that.commandParamsStage != null) {
+      return false;
+    }
+
+    if (hostParamsStage != null ? !hostParamsStage.equals(that.hostParamsStage) : that.hostParamsStage != null) {
+      return false;
+    }
+
+    return !(requestContext != null ? !requestContext.equals(that.requestContext) : that.requestContext != null);
   }
   }
 
 
   @Override
   @Override

+ 130 - 23
ambari-server/src/test/java/org/apache/ambari/server/actionmanager/TestActionDBAccessorImpl.java

@@ -17,15 +17,20 @@
  */
  */
 package org.apache.ambari.server.actionmanager;
 package org.apache.ambari.server.actionmanager;
 
 
-import com.google.inject.AbstractModule;
-import com.google.inject.Guice;
-import com.google.inject.Inject;
-import com.google.inject.Injector;
-import com.google.inject.Singleton;
-import com.google.inject.persist.PersistService;
-import com.google.inject.persist.UnitOfWork;
-import com.google.inject.util.Modules;
+import static org.apache.ambari.server.orm.DBAccessor.DbType.ORACLE;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import javax.persistence.EntityManager;
+
 import junit.framework.Assert;
 import junit.framework.Assert;
+
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.Role;
 import org.apache.ambari.server.Role;
 import org.apache.ambari.server.RoleCommand;
 import org.apache.ambari.server.RoleCommand;
@@ -40,6 +45,7 @@ import org.apache.ambari.server.orm.DBAccessor;
 import org.apache.ambari.server.orm.DBAccessorImpl;
 import org.apache.ambari.server.orm.DBAccessorImpl;
 import org.apache.ambari.server.orm.GuiceJpaInitializer;
 import org.apache.ambari.server.orm.GuiceJpaInitializer;
 import org.apache.ambari.server.orm.InMemoryDefaultTestModule;
 import org.apache.ambari.server.orm.InMemoryDefaultTestModule;
+import org.apache.ambari.server.orm.dao.DaoUtils;
 import org.apache.ambari.server.orm.dao.ExecutionCommandDAO;
 import org.apache.ambari.server.orm.dao.ExecutionCommandDAO;
 import org.apache.ambari.server.orm.dao.HostRoleCommandDAO;
 import org.apache.ambari.server.orm.dao.HostRoleCommandDAO;
 import org.apache.ambari.server.orm.entities.HostRoleCommandEntity;
 import org.apache.ambari.server.orm.entities.HostRoleCommandEntity;
@@ -52,14 +58,16 @@ import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.apache.ambari.server.orm.DBAccessor.DbType.ORACLE;
+
+import com.google.inject.AbstractModule;
+import com.google.inject.Guice;
+import com.google.inject.Inject;
+import com.google.inject.Injector;
+import com.google.inject.Provider;
+import com.google.inject.Singleton;
+import com.google.inject.persist.PersistService;
+import com.google.inject.persist.UnitOfWork;
+import com.google.inject.util.Modules;
 
 
 public class TestActionDBAccessorImpl {
 public class TestActionDBAccessorImpl {
   private static final Logger log = LoggerFactory.getLogger(TestActionDBAccessorImpl.class);
   private static final Logger log = LoggerFactory.getLogger(TestActionDBAccessorImpl.class);
@@ -79,11 +87,19 @@ public class TestActionDBAccessorImpl {
 
 
   @Inject
   @Inject
   private Clusters clusters;
   private Clusters clusters;
+
   @Inject
   @Inject
   private ExecutionCommandDAO executionCommandDAO;
   private ExecutionCommandDAO executionCommandDAO;
+
   @Inject
   @Inject
   private HostRoleCommandDAO hostRoleCommandDAO;
   private HostRoleCommandDAO hostRoleCommandDAO;
 
 
+  @Inject
+  private Provider<EntityManager> entityManagerProvider;
+
+  @Inject
+  private DaoUtils daoUtils;
+
   @Before
   @Before
   public void setup() throws AmbariException {
   public void setup() throws AmbariException {
     InMemoryDefaultTestModule defaultTestModule = new InMemoryDefaultTestModule();
     InMemoryDefaultTestModule defaultTestModule = new InMemoryDefaultTestModule();
@@ -171,10 +187,9 @@ public class TestActionDBAccessorImpl {
 
 
   @Test
   @Test
   public void testGetStagesInProgress() throws AmbariException {
   public void testGetStagesInProgress() throws AmbariException {
-    String hostname = "host1";
     List<Stage> stages = new ArrayList<Stage>();
     List<Stage> stages = new ArrayList<Stage>();
-    stages.add(createStubStage(hostname, requestId, stageId));
-    stages.add(createStubStage(hostname, requestId, stageId + 1));
+    stages.add(createStubStage(hostName, requestId, stageId));
+    stages.add(createStubStage(hostName, requestId, stageId + 1));
     Request request = new Request(stages, clusters);
     Request request = new Request(stages, clusters);
     db.persistActions(request);
     db.persistActions(request);
     assertEquals(2, stages.size());
     assertEquals(2, stages.size());
@@ -182,15 +197,93 @@ public class TestActionDBAccessorImpl {
 
 
   @Test
   @Test
   public void testGetStagesInProgressWithFailures() throws AmbariException {
   public void testGetStagesInProgressWithFailures() throws AmbariException {
-    String hostname = "host1";
-    populateActionDB(db, hostname, requestId, stageId);
-    populateActionDB(db, hostname, requestId+1, stageId);
-    db.abortOperation(requestId);
+    populateActionDB(db, hostName, requestId, stageId);
+    populateActionDB(db, hostName, requestId + 1, stageId);
     List<Stage> stages = db.getStagesInProgress();
     List<Stage> stages = db.getStagesInProgress();
+    assertEquals(2, stages.size());
+
+    db.abortOperation(requestId);
+    stages = db.getStagesInProgress();
     assertEquals(1, stages.size());
     assertEquals(1, stages.size());
     assertEquals(requestId+1, stages.get(0).getRequestId());
     assertEquals(requestId+1, stages.get(0).getRequestId());
   }
   }
 
 
+  @Test
+  public void testGetStagesInProgressWithManyStages() throws AmbariException {
+    // create 3 request; each request will have 3 stages, each stage 2 commands
+    populateActionDBMultipleStages(3, db, hostName, requestId, stageId);
+    populateActionDBMultipleStages(3, db, hostName, requestId + 1, stageId + 3);
+    populateActionDBMultipleStages(3, db, hostName, requestId + 2, stageId + 3);
+
+    // verify stages and proper ordering
+    int commandsInProgressCount = db.getCommandsInProgressCount();
+    List<Stage> stages = db.getStagesInProgress();
+    assertEquals(18, commandsInProgressCount);
+    assertEquals(9, stages.size());
+
+    long lastRequestId = Integer.MIN_VALUE;
+    for (Stage stage : stages) {
+      assertTrue(stage.getRequestId() >= lastRequestId);
+      lastRequestId = stage.getRequestId();
+    }
+
+    // cancel the first one, removing 3 stages
+    db.abortOperation(requestId);
+
+    // verify stages and proper ordering
+    commandsInProgressCount = db.getCommandsInProgressCount();
+    stages = db.getStagesInProgress();
+    assertEquals(12, commandsInProgressCount);
+    assertEquals(6, stages.size());
+
+    // find the first stage, and change one command to COMPLETED
+    stages.get(0).setHostRoleStatus(hostName, Role.HBASE_MASTER.toString(),
+        HostRoleStatus.COMPLETED);
+
+    db.hostRoleScheduled(stages.get(0), hostName, Role.HBASE_MASTER.toString());
+
+    // the first stage still has at least 1 command IN_PROGRESS
+    commandsInProgressCount = db.getCommandsInProgressCount();
+    stages = db.getStagesInProgress();
+    assertEquals(11, commandsInProgressCount);
+    assertEquals(6, stages.size());
+
+    // find the first stage, and change the other command to COMPLETED
+    stages.get(0).setHostRoleStatus(hostName,
+        Role.HBASE_REGIONSERVER.toString(), HostRoleStatus.COMPLETED);
+
+    db.hostRoleScheduled(stages.get(0), hostName,
+        Role.HBASE_REGIONSERVER.toString());
+
+    // verify stages and proper ordering
+    commandsInProgressCount = db.getCommandsInProgressCount();
+    stages = db.getStagesInProgress();
+    assertEquals(10, commandsInProgressCount);
+    assertEquals(5, stages.size());
+  }
+
+  @Test
+  public void testGetStagesInProgressWithManyCommands() throws AmbariException {
+    // 1000 hosts
+    for (int i = 0; i < 1000; i++) {
+      String hostName = "c64-" + i;
+      clusters.addHost(hostName);
+      clusters.getHost(hostName).persist();
+    }
+
+    // create 1 request, 3 stages per host, each with 2 commands
+    for (int i = 0; i < 1000; i++) {
+      String hostName = "c64-" + i;
+      populateActionDBMultipleStages(3, db, hostName, requestId + i, stageId);
+    }
+
+    int commandsInProgressCount = db.getCommandsInProgressCount();
+    List<Stage> stages = db.getStagesInProgress();
+    assertEquals(6000, commandsInProgressCount);
+    assertEquals(3000, stages.size());
+  }
+
+
   @Test
   @Test
   public void testPersistActions() throws AmbariException {
   public void testPersistActions() throws AmbariException {
     populateActionDB(db, hostName, requestId, stageId);
     populateActionDB(db, hostName, requestId, stageId);
@@ -539,6 +632,20 @@ public class TestActionDBAccessorImpl {
     db.persistActions(request);
     db.persistActions(request);
   }
   }
 
 
+  private void populateActionDBMultipleStages(int numberOfStages,
+      ActionDBAccessor db, String hostname, long requestId, long stageId)
+      throws AmbariException {
+
+    List<Stage> stages = new ArrayList<Stage>();
+    for (int i = 0; i < numberOfStages; i++) {
+      Stage stage = createStubStage(hostname, requestId, stageId + i);
+      stages.add(stage);
+    }
+
+    Request request = new Request(stages, clusters);
+    db.persistActions(request);
+  }
+
   private Stage createStubStage(String hostname, long requestId, long stageId) {
   private Stage createStubStage(String hostname, long requestId, long stageId) {
     Stage s = new Stage(requestId, "/a/b", "cluster1", 1L, "action db accessor test",
     Stage s = new Stage(requestId, "/a/b", "cluster1", 1L, "action db accessor test",
       "clusterHostInfo", "commandParamsStage", "hostParamsStage");
       "clusterHostInfo", "commandParamsStage", "hostParamsStage");

+ 50 - 24
ambari-server/src/test/java/org/apache/ambari/server/actionmanager/TestActionScheduler.java

@@ -17,22 +17,37 @@
  */
  */
 package org.apache.ambari.server.actionmanager;
 package org.apache.ambari.server.actionmanager;
 
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyCollectionOf;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.*;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
 import java.lang.reflect.Type;
 import java.lang.reflect.Type;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.TreeMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
 
 
-import com.google.common.reflect.TypeToken;
-import com.google.inject.AbstractModule;
-import com.google.inject.Guice;
-import com.google.inject.Injector;
-import com.google.inject.persist.UnitOfWork;
 import junit.framework.Assert;
 import junit.framework.Assert;
+
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.AmbariException;
 import org.apache.ambari.server.Role;
 import org.apache.ambari.server.Role;
 import org.apache.ambari.server.RoleCommand;
 import org.apache.ambari.server.RoleCommand;
@@ -72,6 +87,12 @@ import org.mockito.stubbing.Answer;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
+import com.google.common.reflect.TypeToken;
+import com.google.inject.AbstractModule;
+import com.google.inject.Guice;
+import com.google.inject.Injector;
+import com.google.inject.persist.UnitOfWork;
+
 public class TestActionScheduler {
 public class TestActionScheduler {
 
 
   private static final Logger log = LoggerFactory.getLogger(TestActionScheduler.class);
   private static final Logger log = LoggerFactory.getLogger(TestActionScheduler.class);
@@ -135,6 +156,8 @@ public class TestActionScheduler {
     Stage s = StageUtils.getATestStage(1, 977, hostname, CLUSTER_HOST_INFO,
     Stage s = StageUtils.getATestStage(1, 977, hostname, CLUSTER_HOST_INFO,
       "{\"host_param\":\"param_value\"}", "{\"stage_param\":\"param_value\"}");
       "{\"host_param\":\"param_value\"}", "{\"stage_param\":\"param_value\"}");
     stages.add(s);
     stages.add(s);
+
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     Request request = mock(Request.class);
     Request request = mock(Request.class);
@@ -227,6 +250,7 @@ public class TestActionScheduler {
     stages.add(s);
     stages.add(s);
 
 
     ActionDBAccessor db = mock(ActionDBAccessor.class);
     ActionDBAccessor db = mock(ActionDBAccessor.class);
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
@@ -309,7 +333,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -395,7 +419,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     doAnswer(new Answer() {
     doAnswer(new Answer() {
@@ -528,7 +552,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -621,7 +645,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -706,7 +730,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -837,7 +861,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     Properties properties = new Properties();
     Properties properties = new Properties();
@@ -928,7 +952,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     Properties properties = new Properties();
     Properties properties = new Properties();
@@ -1004,7 +1028,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     Properties properties = new Properties();
     Properties properties = new Properties();
@@ -1066,7 +1090,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -1247,7 +1271,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -1424,7 +1448,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -1657,6 +1681,8 @@ public class TestActionScheduler {
       "{\"host_param\":\"param_value\"}", "{\"stage_param\":\"param_value\"}");
       "{\"host_param\":\"param_value\"}", "{\"stage_param\":\"param_value\"}");
     Stage s2 = StageUtils.getATestStage(requestId2, stageId, hostname, CLUSTER_HOST_INFO_UPDATED,
     Stage s2 = StageUtils.getATestStage(requestId2, stageId, hostname, CLUSTER_HOST_INFO_UPDATED,
       "{\"host_param\":\"param_value\"}", "{\"stage_param\":\"param_value\"}");
       "{\"host_param\":\"param_value\"}", "{\"stage_param\":\"param_value\"}");
+
+    when(db.getCommandsInProgressCount()).thenReturn(1);
     when(db.getStagesInProgress()).thenReturn(Collections.singletonList(s1));
     when(db.getStagesInProgress()).thenReturn(Collections.singletonList(s1));
 
 
     //Keep large number of attempts so that the task is not expired finally
     //Keep large number of attempts so that the task is not expired finally
@@ -1672,7 +1698,7 @@ public class TestActionScheduler {
 
 
     assertEquals(clusterHostInfo1, ((ExecutionCommand) (ac.get(0))).getClusterHostInfo());
     assertEquals(clusterHostInfo1, ((ExecutionCommand) (ac.get(0))).getClusterHostInfo());
 
 
-
+    when(db.getCommandsInProgressCount()).thenReturn(1);
     when(db.getStagesInProgress()).thenReturn(Collections.singletonList(s2));
     when(db.getStagesInProgress()).thenReturn(Collections.singletonList(s2));
 
 
     //Verify that ActionSheduler does not return cached value of cluster host info for new requestId
     //Verify that ActionSheduler does not return cached value of cluster host info for new requestId
@@ -1737,7 +1763,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     ActionScheduler scheduler = new ActionScheduler(100, 50000, db, aq, fsm, 3,
     ActionScheduler scheduler = new ActionScheduler(100, 50000, db, aq, fsm, 3,
@@ -1818,7 +1844,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
     doAnswer(new Answer() {
     doAnswer(new Answer() {
       @Override
       @Override
@@ -1913,7 +1939,7 @@ public class TestActionScheduler {
     Request request = mock(Request.class);
     Request request = mock(Request.class);
     when(request.isExclusive()).thenReturn(false);
     when(request.isExclusive()).thenReturn(false);
     when(db.getRequest(anyLong())).thenReturn(request);
     when(db.getRequest(anyLong())).thenReturn(request);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stages.size());
     when(db.getStagesInProgress()).thenReturn(stages);
     when(db.getStagesInProgress()).thenReturn(stages);
 
 
     List<HostRoleCommand> requestTasks = new ArrayList<HostRoleCommand>();
     List<HostRoleCommand> requestTasks = new ArrayList<HostRoleCommand>();
@@ -2078,7 +2104,7 @@ public class TestActionScheduler {
     when(host3.getHostName()).thenReturn(hostname);
     when(host3.getHostName()).thenReturn(hostname);
 
 
     ActionDBAccessor db = mock(ActionDBAccessor.class);
     ActionDBAccessor db = mock(ActionDBAccessor.class);
-
+    when(db.getCommandsInProgressCount()).thenReturn(stagesInProgress.size());
     when(db.getStagesInProgress()).thenReturn(stagesInProgress);
     when(db.getStagesInProgress()).thenReturn(stagesInProgress);
 
 
     List<HostRoleCommand> requestTasks = new ArrayList<HostRoleCommand>();
     List<HostRoleCommand> requestTasks = new ArrayList<HostRoleCommand>();