Przeglądaj źródła

MAPREDUCE-4607. Race condition in ReduceTask completion can result in Task being incorrectly failed. Contributed by Bikas Saha.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1383422 13f79535-47bb-0310-9956-ffa450edef68
Thomas White 12 lat temu
rodzic
commit
3b46295c28

+ 33 - 9
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/job/impl/TaskAttemptImpl.java

@@ -71,6 +71,7 @@ import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptId;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptReport;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptReport;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptState;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptState;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskId;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskId;
+import org.apache.hadoop.mapreduce.v2.api.records.TaskState;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskType;
 import org.apache.hadoop.mapreduce.v2.api.records.TaskType;
 import org.apache.hadoop.mapreduce.v2.app.AppContext;
 import org.apache.hadoop.mapreduce.v2.app.AppContext;
 import org.apache.hadoop.mapreduce.v2.app.TaskAttemptListener;
 import org.apache.hadoop.mapreduce.v2.app.TaskAttemptListener;
@@ -86,6 +87,7 @@ import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEvent;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEventType;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEventType;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptKillEvent;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptKillEvent;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent;
+import org.apache.hadoop.mapreduce.v2.app.job.event.TaskEvent;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent.TaskAttemptStatus;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent.TaskAttemptStatus;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskEventType;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskEventType;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskTAttemptEvent;
 import org.apache.hadoop.mapreduce.v2.app.job.event.TaskTAttemptEvent;
@@ -120,6 +122,7 @@ import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.hadoop.yarn.factories.RecordFactory;
 import org.apache.hadoop.yarn.factories.RecordFactory;
 import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider;
 import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider;
 import org.apache.hadoop.yarn.state.InvalidStateTransitonException;
 import org.apache.hadoop.yarn.state.InvalidStateTransitonException;
+import org.apache.hadoop.yarn.state.MultipleArcTransition;
 import org.apache.hadoop.yarn.state.SingleArcTransition;
 import org.apache.hadoop.yarn.state.SingleArcTransition;
 import org.apache.hadoop.yarn.state.StateMachine;
 import org.apache.hadoop.yarn.state.StateMachine;
 import org.apache.hadoop.yarn.state.StateMachineFactory;
 import org.apache.hadoop.yarn.state.StateMachineFactory;
@@ -128,6 +131,8 @@ import org.apache.hadoop.yarn.util.BuilderUtils;
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.apache.hadoop.yarn.util.RackResolver;
 import org.apache.hadoop.yarn.util.RackResolver;
 
 
+import com.google.common.base.Preconditions;
+
 /**
 /**
  * Implementation of TaskAttempt interface.
  * Implementation of TaskAttempt interface.
  */
  */
@@ -404,10 +409,10 @@ public abstract class TaskAttemptImpl implements
          TaskAttemptState.FAILED,
          TaskAttemptState.FAILED,
          TaskAttemptEventType.TA_TOO_MANY_FETCH_FAILURE,
          TaskAttemptEventType.TA_TOO_MANY_FETCH_FAILURE,
          new TooManyFetchFailureTransition())
          new TooManyFetchFailureTransition())
-     .addTransition(
-         TaskAttemptState.SUCCEEDED, TaskAttemptState.KILLED,
-         TaskAttemptEventType.TA_KILL,
-         new KilledAfterSuccessTransition())
+      .addTransition(TaskAttemptState.SUCCEEDED,
+          EnumSet.of(TaskAttemptState.SUCCEEDED, TaskAttemptState.KILLED),
+          TaskAttemptEventType.TA_KILL, 
+          new KilledAfterSuccessTransition())
      .addTransition(
      .addTransition(
          TaskAttemptState.SUCCEEDED, TaskAttemptState.SUCCEEDED,
          TaskAttemptState.SUCCEEDED, TaskAttemptState.SUCCEEDED,
          TaskAttemptEventType.TA_DIAGNOSTICS_UPDATE,
          TaskAttemptEventType.TA_DIAGNOSTICS_UPDATE,
@@ -1483,6 +1488,9 @@ public abstract class TaskAttemptImpl implements
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     @Override
     @Override
     public void transition(TaskAttemptImpl taskAttempt, TaskAttemptEvent event) {
     public void transition(TaskAttemptImpl taskAttempt, TaskAttemptEvent event) {
+      // too many fetch failure can only happen for map tasks
+      Preconditions
+          .checkArgument(taskAttempt.getID().getTaskId().getTaskType() == TaskType.MAP);
       //add to diagnostic
       //add to diagnostic
       taskAttempt.addDiagnosticInfo("Too Many fetch failures.Failing the attempt");
       taskAttempt.addDiagnosticInfo("Too Many fetch failures.Failing the attempt");
       //set the finish time
       //set the finish time
@@ -1506,15 +1514,30 @@ public abstract class TaskAttemptImpl implements
   }
   }
   
   
   private static class KilledAfterSuccessTransition implements
   private static class KilledAfterSuccessTransition implements
-      SingleArcTransition<TaskAttemptImpl, TaskAttemptEvent> {
+      MultipleArcTransition<TaskAttemptImpl, TaskAttemptEvent, TaskAttemptState> {
 
 
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     @Override
     @Override
-    public void transition(TaskAttemptImpl taskAttempt, 
+    public TaskAttemptState transition(TaskAttemptImpl taskAttempt, 
         TaskAttemptEvent event) {
         TaskAttemptEvent event) {
-      TaskAttemptKillEvent msgEvent = (TaskAttemptKillEvent) event;
-      //add to diagnostic
-      taskAttempt.addDiagnosticInfo(msgEvent.getMessage());
+      if(taskAttempt.getID().getTaskId().getTaskType() == TaskType.REDUCE) {
+        // after a reduce task has succeeded, its outputs are in safe in HDFS.
+        // logically such a task should not be killed. we only come here when
+        // there is a race condition in the event queue. E.g. some logic sends
+        // a kill request to this attempt when the successful completion event
+        // for this task is already in the event queue. so the kill event will
+        // get executed immediately after the attempt is marked successful and 
+        // result in this transition being exercised.
+        // ignore this for reduce tasks
+        LOG.info("Ignoring killed event for successful reduce task attempt" +
+                  taskAttempt.getID().toString());
+        return TaskAttemptState.SUCCEEDED;
+      }
+      if(event instanceof TaskAttemptKillEvent) {
+        TaskAttemptKillEvent msgEvent = (TaskAttemptKillEvent) event;
+        //add to diagnostic
+        taskAttempt.addDiagnosticInfo(msgEvent.getMessage());
+      }
 
 
       // not setting a finish time since it was set on success
       // not setting a finish time since it was set on success
       assert (taskAttempt.getFinishTime() != 0);
       assert (taskAttempt.getFinishTime() != 0);
@@ -1528,6 +1551,7 @@ public abstract class TaskAttemptImpl implements
           .getTaskId().getJobId(), tauce));
           .getTaskId().getJobId(), tauce));
       taskAttempt.eventHandler.handle(new TaskTAttemptEvent(
       taskAttempt.eventHandler.handle(new TaskTAttemptEvent(
           taskAttempt.attemptId, TaskEventType.T_ATTEMPT_KILLED));
           taskAttempt.attemptId, TaskEventType.T_ATTEMPT_KILLED));
+      return TaskAttemptState.KILLED;
     }
     }
   }
   }
 
 

+ 37 - 33
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/job/impl/TaskImpl.java

@@ -191,12 +191,12 @@ public abstract class TaskImpl implements Task, EventHandler<TaskEvent> {
             TaskEventType.T_ADD_SPEC_ATTEMPT))
             TaskEventType.T_ADD_SPEC_ATTEMPT))
 
 
     // Transitions from SUCCEEDED state
     // Transitions from SUCCEEDED state
-    .addTransition(TaskState.SUCCEEDED, //only possible for map tasks
+    .addTransition(TaskState.SUCCEEDED,
         EnumSet.of(TaskState.SCHEDULED, TaskState.SUCCEEDED, TaskState.FAILED),
         EnumSet.of(TaskState.SCHEDULED, TaskState.SUCCEEDED, TaskState.FAILED),
-        TaskEventType.T_ATTEMPT_FAILED, new MapRetroactiveFailureTransition())
-    .addTransition(TaskState.SUCCEEDED, //only possible for map tasks
+        TaskEventType.T_ATTEMPT_FAILED, new RetroactiveFailureTransition())
+    .addTransition(TaskState.SUCCEEDED,
         EnumSet.of(TaskState.SCHEDULED, TaskState.SUCCEEDED),
         EnumSet.of(TaskState.SCHEDULED, TaskState.SUCCEEDED),
-        TaskEventType.T_ATTEMPT_KILLED, new MapRetroactiveKilledTransition())
+        TaskEventType.T_ATTEMPT_KILLED, new RetroactiveKilledTransition())
     // Ignore-able transitions.
     // Ignore-able transitions.
     .addTransition(
     .addTransition(
         TaskState.SUCCEEDED, TaskState.SUCCEEDED,
         TaskState.SUCCEEDED, TaskState.SUCCEEDED,
@@ -897,7 +897,7 @@ public abstract class TaskImpl implements Task, EventHandler<TaskEvent> {
     }
     }
   }
   }
 
 
-  private static class MapRetroactiveFailureTransition
+  private static class RetroactiveFailureTransition
       extends AttemptFailedTransition {
       extends AttemptFailedTransition {
 
 
     @Override
     @Override
@@ -911,8 +911,8 @@ public abstract class TaskImpl implements Task, EventHandler<TaskEvent> {
           return TaskState.SUCCEEDED;
           return TaskState.SUCCEEDED;
         }
         }
       }
       }
-      
-      //verify that this occurs only for map task
+
+      // a successful REDUCE task should not be overridden
       //TODO: consider moving it to MapTaskImpl
       //TODO: consider moving it to MapTaskImpl
       if (!TaskType.MAP.equals(task.getType())) {
       if (!TaskType.MAP.equals(task.getType())) {
         LOG.error("Unexpected event for REDUCE task " + event.getType());
         LOG.error("Unexpected event for REDUCE task " + event.getType());
@@ -938,42 +938,46 @@ public abstract class TaskImpl implements Task, EventHandler<TaskEvent> {
     }
     }
   }
   }
 
 
-  private static class MapRetroactiveKilledTransition implements
+  private static class RetroactiveKilledTransition implements
     MultipleArcTransition<TaskImpl, TaskEvent, TaskState> {
     MultipleArcTransition<TaskImpl, TaskEvent, TaskState> {
 
 
     @Override
     @Override
     public TaskState transition(TaskImpl task, TaskEvent event) {
     public TaskState transition(TaskImpl task, TaskEvent event) {
-      // verify that this occurs only for map task
+      TaskAttemptId attemptId = null;
+      if (event instanceof TaskTAttemptEvent) {
+        TaskTAttemptEvent castEvent = (TaskTAttemptEvent) event;
+        attemptId = castEvent.getTaskAttemptID(); 
+        if (task.getState() == TaskState.SUCCEEDED &&
+            !attemptId.equals(task.successfulAttempt)) {
+          // don't allow a different task attempt to override a previous
+          // succeeded state
+          return TaskState.SUCCEEDED;
+        }
+      }
+
+      // a successful REDUCE task should not be overridden
       // TODO: consider moving it to MapTaskImpl
       // TODO: consider moving it to MapTaskImpl
       if (!TaskType.MAP.equals(task.getType())) {
       if (!TaskType.MAP.equals(task.getType())) {
         LOG.error("Unexpected event for REDUCE task " + event.getType());
         LOG.error("Unexpected event for REDUCE task " + event.getType());
         task.internalError(event.getType());
         task.internalError(event.getType());
       }
       }
 
 
-      TaskTAttemptEvent attemptEvent = (TaskTAttemptEvent) event;
-      TaskAttemptId attemptId = attemptEvent.getTaskAttemptID();
-      if(task.successfulAttempt == attemptId) {
-        // successful attempt is now killed. reschedule
-        // tell the job about the rescheduling
-        unSucceed(task);
-        task.handleTaskAttemptCompletion(
-            attemptId, 
-            TaskAttemptCompletionEventStatus.KILLED);
-        task.eventHandler.handle(new JobMapTaskRescheduledEvent(task.taskId));
-        // typically we are here because this map task was run on a bad node and 
-        // we want to reschedule it on a different node.
-        // Depending on whether there are previous failed attempts or not this 
-        // can SCHEDULE or RESCHEDULE the container allocate request. If this
-        // SCHEDULE's then the dataLocal hosts of this taskAttempt will be used
-        // from the map splitInfo. So the bad node might be sent as a location 
-        // to the RM. But the RM would ignore that just like it would ignore 
-        // currently pending container requests affinitized to bad nodes.
-        task.addAndScheduleAttempt();
-        return TaskState.SCHEDULED;
-      } else {
-        // nothing to do
-        return TaskState.SUCCEEDED;
-      }
+      // successful attempt is now killed. reschedule
+      // tell the job about the rescheduling
+      unSucceed(task);
+      task.handleTaskAttemptCompletion(attemptId,
+          TaskAttemptCompletionEventStatus.KILLED);
+      task.eventHandler.handle(new JobMapTaskRescheduledEvent(task.taskId));
+      // typically we are here because this map task was run on a bad node and
+      // we want to reschedule it on a different node.
+      // Depending on whether there are previous failed attempts or not this
+      // can SCHEDULE or RESCHEDULE the container allocate request. If this
+      // SCHEDULE's then the dataLocal hosts of this taskAttempt will be used
+      // from the map splitInfo. So the bad node might be sent as a location
+      // to the RM. But the RM would ignore that just like it would ignore
+      // currently pending container requests affinitized to bad nodes.
+      task.addAndScheduleAttempt();
+      return TaskState.SCHEDULED;
     }
     }
   }
   }
 
 

+ 28 - 10
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/TestMRApp.java

@@ -180,7 +180,7 @@ public class TestMRApp {
   @Test
   @Test
   public void testUpdatedNodes() throws Exception {
   public void testUpdatedNodes() throws Exception {
     int runCount = 0;
     int runCount = 0;
-    MRApp app = new MRAppWithHistory(2, 1, false, this.getClass().getName(),
+    MRApp app = new MRAppWithHistory(2, 2, false, this.getClass().getName(),
         true, ++runCount);
         true, ++runCount);
     Configuration conf = new Configuration();
     Configuration conf = new Configuration();
     // after half of the map completion, reduce will start
     // after half of the map completion, reduce will start
@@ -189,7 +189,7 @@ public class TestMRApp {
     conf.setBoolean(MRJobConfig.JOB_UBERTASK_ENABLE, false);
     conf.setBoolean(MRJobConfig.JOB_UBERTASK_ENABLE, false);
     Job job = app.submit(conf);
     Job job = app.submit(conf);
     app.waitForState(job, JobState.RUNNING);
     app.waitForState(job, JobState.RUNNING);
-    Assert.assertEquals("Num tasks not correct", 3, job.getTasks().size());
+    Assert.assertEquals("Num tasks not correct", 4, job.getTasks().size());
     Iterator<Task> it = job.getTasks().values().iterator();
     Iterator<Task> it = job.getTasks().values().iterator();
     Task mapTask1 = it.next();
     Task mapTask1 = it.next();
     Task mapTask2 = it.next();
     Task mapTask2 = it.next();
@@ -272,18 +272,19 @@ public class TestMRApp {
 
 
     // rerun
     // rerun
     // in rerun the 1st map will be recovered from previous run
     // in rerun the 1st map will be recovered from previous run
-    app = new MRAppWithHistory(2, 1, false, this.getClass().getName(), false,
+    app = new MRAppWithHistory(2, 2, false, this.getClass().getName(), false,
         ++runCount);
         ++runCount);
     conf = new Configuration();
     conf = new Configuration();
     conf.setBoolean(MRJobConfig.MR_AM_JOB_RECOVERY_ENABLE, true);
     conf.setBoolean(MRJobConfig.MR_AM_JOB_RECOVERY_ENABLE, true);
     conf.setBoolean(MRJobConfig.JOB_UBERTASK_ENABLE, false);
     conf.setBoolean(MRJobConfig.JOB_UBERTASK_ENABLE, false);
     job = app.submit(conf);
     job = app.submit(conf);
     app.waitForState(job, JobState.RUNNING);
     app.waitForState(job, JobState.RUNNING);
-    Assert.assertEquals("No of tasks not correct", 3, job.getTasks().size());
+    Assert.assertEquals("No of tasks not correct", 4, job.getTasks().size());
     it = job.getTasks().values().iterator();
     it = job.getTasks().values().iterator();
     mapTask1 = it.next();
     mapTask1 = it.next();
     mapTask2 = it.next();
     mapTask2 = it.next();
-    Task reduceTask = it.next();
+    Task reduceTask1 = it.next();
+    Task reduceTask2 = it.next();
 
 
     // map 1 will be recovered, no need to send done
     // map 1 will be recovered, no need to send done
     app.waitForState(mapTask1, TaskState.SUCCEEDED);
     app.waitForState(mapTask1, TaskState.SUCCEEDED);
@@ -306,19 +307,36 @@ public class TestMRApp {
     Assert.assertEquals("Expecting 1 more completion events for success", 3,
     Assert.assertEquals("Expecting 1 more completion events for success", 3,
         events.length);
         events.length);
 
 
-    app.waitForState(reduceTask, TaskState.RUNNING);
-    TaskAttempt task3Attempt = reduceTask.getAttempts().values().iterator()
+    app.waitForState(reduceTask1, TaskState.RUNNING);
+    app.waitForState(reduceTask2, TaskState.RUNNING);
+
+    TaskAttempt task3Attempt = reduceTask1.getAttempts().values().iterator()
         .next();
         .next();
     app.getContext()
     app.getContext()
         .getEventHandler()
         .getEventHandler()
         .handle(
         .handle(
             new TaskAttemptEvent(task3Attempt.getID(),
             new TaskAttemptEvent(task3Attempt.getID(),
                 TaskAttemptEventType.TA_DONE));
                 TaskAttemptEventType.TA_DONE));
-    app.waitForState(reduceTask, TaskState.SUCCEEDED);
+    app.waitForState(reduceTask1, TaskState.SUCCEEDED);
+    app.getContext()
+    .getEventHandler()
+    .handle(
+        new TaskAttemptEvent(task3Attempt.getID(),
+            TaskAttemptEventType.TA_KILL));
+    app.waitForState(reduceTask1, TaskState.SUCCEEDED);
+    
+    TaskAttempt task4Attempt = reduceTask2.getAttempts().values().iterator()
+        .next();
+    app.getContext()
+        .getEventHandler()
+        .handle(
+            new TaskAttemptEvent(task4Attempt.getID(),
+                TaskAttemptEventType.TA_DONE));
+    app.waitForState(reduceTask2, TaskState.SUCCEEDED);    
 
 
     events = job.getTaskAttemptCompletionEvents(0, 100);
     events = job.getTaskAttemptCompletionEvents(0, 100);
-    Assert.assertEquals("Expecting 1 more completion events for success", 4,
-        events.length);
+    Assert.assertEquals("Expecting 2 more completion events for reduce success",
+        5, events.length);
 
 
     // job succeeds
     // job succeeds
     app.waitForState(job, JobState.SUCCEEDED);
     app.waitForState(job, JobState.SUCCEEDED);

+ 54 - 15
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/job/impl/TestTaskImpl.java

@@ -84,7 +84,6 @@ public class TestTaskImpl {
   private ApplicationId appId;
   private ApplicationId appId;
   private TaskSplitMetaInfo taskSplitMetaInfo;  
   private TaskSplitMetaInfo taskSplitMetaInfo;  
   private String[] dataLocations = new String[0]; 
   private String[] dataLocations = new String[0]; 
-  private final TaskType taskType = TaskType.MAP;
   private AppContext appContext;
   private AppContext appContext;
   
   
   private int startCount = 0;
   private int startCount = 0;
@@ -97,6 +96,7 @@ public class TestTaskImpl {
   private class MockTaskImpl extends TaskImpl {
   private class MockTaskImpl extends TaskImpl {
         
         
     private int taskAttemptCounter = 0;
     private int taskAttemptCounter = 0;
+    TaskType taskType;
 
 
     public MockTaskImpl(JobId jobId, int partition,
     public MockTaskImpl(JobId jobId, int partition,
         EventHandler eventHandler, Path remoteJobConfFile, JobConf conf,
         EventHandler eventHandler, Path remoteJobConfFile, JobConf conf,
@@ -104,11 +104,12 @@ public class TestTaskImpl {
         Token<JobTokenIdentifier> jobToken,
         Token<JobTokenIdentifier> jobToken,
         Credentials credentials, Clock clock,
         Credentials credentials, Clock clock,
         Map<TaskId, TaskInfo> completedTasksFromPreviousRun, int startCount,
         Map<TaskId, TaskInfo> completedTasksFromPreviousRun, int startCount,
-        MRAppMetrics metrics, AppContext appContext) {
+        MRAppMetrics metrics, AppContext appContext, TaskType taskType) {
       super(jobId, taskType , partition, eventHandler,
       super(jobId, taskType , partition, eventHandler,
           remoteJobConfFile, conf, taskAttemptListener, committer, 
           remoteJobConfFile, conf, taskAttemptListener, committer, 
           jobToken, credentials, clock,
           jobToken, credentials, clock,
           completedTasksFromPreviousRun, startCount, metrics, appContext);
           completedTasksFromPreviousRun, startCount, metrics, appContext);
+      this.taskType = taskType;
     }
     }
 
 
     @Override
     @Override
@@ -120,7 +121,7 @@ public class TestTaskImpl {
     protected TaskAttemptImpl createAttempt() {
     protected TaskAttemptImpl createAttempt() {
       MockTaskAttemptImpl attempt = new MockTaskAttemptImpl(getID(), ++taskAttemptCounter, 
       MockTaskAttemptImpl attempt = new MockTaskAttemptImpl(getID(), ++taskAttemptCounter, 
           eventHandler, taskAttemptListener, remoteJobConfFile, partition,
           eventHandler, taskAttemptListener, remoteJobConfFile, partition,
-          conf, committer, jobToken, credentials, clock, appContext);
+          conf, committer, jobToken, credentials, clock, appContext, taskType);
       taskAttempts.add(attempt);
       taskAttempts.add(attempt);
       return attempt;
       return attempt;
     }
     }
@@ -142,18 +143,20 @@ public class TestTaskImpl {
     private float progress = 0;
     private float progress = 0;
     private TaskAttemptState state = TaskAttemptState.NEW;
     private TaskAttemptState state = TaskAttemptState.NEW;
     private TaskAttemptId attemptId;
     private TaskAttemptId attemptId;
+    private TaskType taskType;
 
 
     public MockTaskAttemptImpl(TaskId taskId, int id, EventHandler eventHandler,
     public MockTaskAttemptImpl(TaskId taskId, int id, EventHandler eventHandler,
         TaskAttemptListener taskAttemptListener, Path jobFile, int partition,
         TaskAttemptListener taskAttemptListener, Path jobFile, int partition,
         JobConf conf, OutputCommitter committer,
         JobConf conf, OutputCommitter committer,
         Token<JobTokenIdentifier> jobToken,
         Token<JobTokenIdentifier> jobToken,
         Credentials credentials, Clock clock,
         Credentials credentials, Clock clock,
-        AppContext appContext) {
+        AppContext appContext, TaskType taskType) {
       super(taskId, id, eventHandler, taskAttemptListener, jobFile, partition, conf,
       super(taskId, id, eventHandler, taskAttemptListener, jobFile, partition, conf,
           dataLocations, committer, jobToken, credentials, clock, appContext);
           dataLocations, committer, jobToken, credentials, clock, appContext);
       attemptId = Records.newRecord(TaskAttemptId.class);
       attemptId = Records.newRecord(TaskAttemptId.class);
       attemptId.setId(id);
       attemptId.setId(id);
       attemptId.setTaskId(taskId);
       attemptId.setTaskId(taskId);
+      this.taskType = taskType;
     }
     }
 
 
     public TaskAttemptId getAttemptId() {
     public TaskAttemptId getAttemptId() {
@@ -162,7 +165,7 @@ public class TestTaskImpl {
     
     
     @Override
     @Override
     protected Task createRemoteTask() {
     protected Task createRemoteTask() {
-      return new MockTask();
+      return new MockTask(taskType);
     }    
     }    
     
     
     public float getProgress() {
     public float getProgress() {
@@ -185,6 +188,11 @@ public class TestTaskImpl {
   
   
   private class MockTask extends Task {
   private class MockTask extends Task {
 
 
+    private TaskType taskType;
+    MockTask(TaskType taskType) {
+      this.taskType = taskType;
+    }
+    
     @Override
     @Override
     public void run(JobConf job, TaskUmbilicalProtocol umbilical)
     public void run(JobConf job, TaskUmbilicalProtocol umbilical)
         throws IOException, ClassNotFoundException, InterruptedException {
         throws IOException, ClassNotFoundException, InterruptedException {
@@ -193,7 +201,7 @@ public class TestTaskImpl {
 
 
     @Override
     @Override
     public boolean isMapTask() {
     public boolean isMapTask() {
-      return true;
+      return (taskType == TaskType.MAP);
     }    
     }    
     
     
   }
   }
@@ -227,14 +235,15 @@ public class TestTaskImpl {
     taskSplitMetaInfo = mock(TaskSplitMetaInfo.class);
     taskSplitMetaInfo = mock(TaskSplitMetaInfo.class);
     when(taskSplitMetaInfo.getLocations()).thenReturn(dataLocations); 
     when(taskSplitMetaInfo.getLocations()).thenReturn(dataLocations); 
     
     
-    taskAttempts = new ArrayList<MockTaskAttemptImpl>();
-    
-    mockTask = new MockTaskImpl(jobId, partition, dispatcher.getEventHandler(),
+    taskAttempts = new ArrayList<MockTaskAttemptImpl>();    
+  }
+  
+  private MockTaskImpl createMockTask(TaskType taskType) {
+    return new MockTaskImpl(jobId, partition, dispatcher.getEventHandler(),
         remoteJobConfFile, conf, taskAttemptListener, committer, jobToken,
         remoteJobConfFile, conf, taskAttemptListener, committer, jobToken,
         credentials, clock,
         credentials, clock,
         completedTasksFromPreviousRun, startCount,
         completedTasksFromPreviousRun, startCount,
-        metrics, appContext);        
-    
+        metrics, appContext, taskType);
   }
   }
 
 
   @After 
   @After 
@@ -342,6 +351,7 @@ public class TestTaskImpl {
   @Test
   @Test
   public void testInit() {
   public void testInit() {
     LOG.info("--- START: testInit ---");
     LOG.info("--- START: testInit ---");
+    mockTask = createMockTask(TaskType.MAP);        
     assertTaskNewState();
     assertTaskNewState();
     assert(taskAttempts.size() == 0);
     assert(taskAttempts.size() == 0);
   }
   }
@@ -352,6 +362,7 @@ public class TestTaskImpl {
    */
    */
   public void testScheduleTask() {
   public void testScheduleTask() {
     LOG.info("--- START: testScheduleTask ---");
     LOG.info("--- START: testScheduleTask ---");
+    mockTask = createMockTask(TaskType.MAP);        
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
   }
   }
@@ -362,6 +373,7 @@ public class TestTaskImpl {
    */
    */
   public void testKillScheduledTask() {
   public void testKillScheduledTask() {
     LOG.info("--- START: testKillScheduledTask ---");
     LOG.info("--- START: testKillScheduledTask ---");
+    mockTask = createMockTask(TaskType.MAP);        
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
     killTask(taskId);
     killTask(taskId);
@@ -374,6 +386,7 @@ public class TestTaskImpl {
    */
    */
   public void testKillScheduledTaskAttempt() {
   public void testKillScheduledTaskAttempt() {
     LOG.info("--- START: testKillScheduledTaskAttempt ---");
     LOG.info("--- START: testKillScheduledTaskAttempt ---");
+    mockTask = createMockTask(TaskType.MAP);        
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
     killScheduledTaskAttempt(getLastAttempt().getAttemptId());
     killScheduledTaskAttempt(getLastAttempt().getAttemptId());
@@ -386,6 +399,7 @@ public class TestTaskImpl {
    */
    */
   public void testLaunchTaskAttempt() {
   public void testLaunchTaskAttempt() {
     LOG.info("--- START: testLaunchTaskAttempt ---");
     LOG.info("--- START: testLaunchTaskAttempt ---");
+    mockTask = createMockTask(TaskType.MAP);        
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
     launchTaskAttempt(getLastAttempt().getAttemptId());
     launchTaskAttempt(getLastAttempt().getAttemptId());
@@ -398,6 +412,7 @@ public class TestTaskImpl {
    */
    */
   public void testKillRunningTaskAttempt() {
   public void testKillRunningTaskAttempt() {
     LOG.info("--- START: testKillRunningTaskAttempt ---");
     LOG.info("--- START: testKillRunningTaskAttempt ---");
+    mockTask = createMockTask(TaskType.MAP);        
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
     launchTaskAttempt(getLastAttempt().getAttemptId());
     launchTaskAttempt(getLastAttempt().getAttemptId());
@@ -407,6 +422,7 @@ public class TestTaskImpl {
   @Test 
   @Test 
   public void testTaskProgress() {
   public void testTaskProgress() {
     LOG.info("--- START: testTaskProgress ---");
     LOG.info("--- START: testTaskProgress ---");
+    mockTask = createMockTask(TaskType.MAP);        
         
         
     // launch task
     // launch task
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
@@ -444,6 +460,7 @@ public class TestTaskImpl {
   
   
   @Test
   @Test
   public void testFailureDuringTaskAttemptCommit() {
   public void testFailureDuringTaskAttemptCommit() {
+    mockTask = createMockTask(TaskType.MAP);        
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
     launchTaskAttempt(getLastAttempt().getAttemptId());
     launchTaskAttempt(getLastAttempt().getAttemptId());
@@ -469,8 +486,7 @@ public class TestTaskImpl {
     assertTaskSucceededState();
     assertTaskSucceededState();
   }
   }
   
   
-  @Test
-  public void testSpeculativeTaskAttemptSucceedsEvenIfFirstFails() {
+  private void runSpeculativeTaskAttemptSucceedsEvenIfFirstFails(TaskEventType failEvent) {
     TaskId taskId = getNewTaskID();
     TaskId taskId = getNewTaskID();
     scheduleTaskAttempt(taskId);
     scheduleTaskAttempt(taskId);
     launchTaskAttempt(getLastAttempt().getAttemptId());
     launchTaskAttempt(getLastAttempt().getAttemptId());
@@ -489,11 +505,34 @@ public class TestTaskImpl {
     
     
     // Now fail the first task attempt, after the second has succeeded
     // Now fail the first task attempt, after the second has succeeded
     mockTask.handle(new TaskTAttemptEvent(taskAttempts.get(0).getAttemptId(), 
     mockTask.handle(new TaskTAttemptEvent(taskAttempts.get(0).getAttemptId(), 
-        TaskEventType.T_ATTEMPT_FAILED));
+        failEvent));
     
     
     // The task should still be in the succeeded state
     // The task should still be in the succeeded state
     assertTaskSucceededState();
     assertTaskSucceededState();
-    
+  }
+  
+  @Test
+  public void testMapSpeculativeTaskAttemptSucceedsEvenIfFirstFails() {
+    mockTask = createMockTask(TaskType.MAP);        
+    runSpeculativeTaskAttemptSucceedsEvenIfFirstFails(TaskEventType.T_ATTEMPT_FAILED);
+  }
+
+  @Test
+  public void testReduceSpeculativeTaskAttemptSucceedsEvenIfFirstFails() {
+    mockTask = createMockTask(TaskType.REDUCE);        
+    runSpeculativeTaskAttemptSucceedsEvenIfFirstFails(TaskEventType.T_ATTEMPT_FAILED);
+  }
+  
+  @Test
+  public void testMapSpeculativeTaskAttemptSucceedsEvenIfFirstIsKilled() {
+    mockTask = createMockTask(TaskType.MAP);        
+    runSpeculativeTaskAttemptSucceedsEvenIfFirstFails(TaskEventType.T_ATTEMPT_KILLED);
+  }
+
+  @Test
+  public void testReduceSpeculativeTaskAttemptSucceedsEvenIfFirstIsKilled() {
+    mockTask = createMockTask(TaskType.REDUCE);        
+    runSpeculativeTaskAttemptSucceedsEvenIfFirstFails(TaskEventType.T_ATTEMPT_KILLED);
   }
   }
 
 
 }
 }