瀏覽代碼

Merge -c 1493445 from trunk to branch-2 to fix MAPREDUCE-5192. Allow for alternate resolutions of TaskCompletionEvents. Contributed by Chris Douglas.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/branches/branch-2@1493446 13f79535-47bb-0310-9956-ffa450edef68
Arun Murthy 12 年之前
父節點
當前提交
a00e4a926d

+ 3 - 0
hadoop-mapreduce-project/CHANGES.txt

@@ -172,6 +172,9 @@ Release 2.1.0-beta - UNRELEASED
     MAPREDUCE-5199. Removing ApplicationTokens file as it is no longer needed.
     (Daryn Sharp via vinodkv)
 
+    MAPREDUCE-5192. Allow for alternate resolutions of TaskCompletionEvents.
+    (cdouglas via acmurthy)
+
   OPTIMIZATIONS
 
     MAPREDUCE-4974. Optimising the LineRecordReader initialize() method 

+ 8 - 1
hadoop-mapreduce-project/dev-support/findbugs-exclude.xml

@@ -271,12 +271,19 @@
        <Class name="org.apache.hadoop.mapreduce.task.reduce.MergeManagerImpl" />
        <Bug pattern="SC_START_IN_CTOR" />
      </Match>
+    <!--
+     This class is unlikely to get subclassed, so ignore
+    -->
+    <Match>
+     <Class name="org.apache.hadoop.mapreduce.task.reduce.ShuffleSchedulerImpl" />
+     <Bug pattern="SC_START_IN_CTOR" />
+    </Match>
 
     <!--
       Do not bother if equals is not implemented. We will not need it here
     -->
      <Match>
-      <Class name="org.apache.hadoop.mapreduce.task.reduce.ShuffleScheduler$Penalty" />
+      <Class name="org.apache.hadoop.mapreduce.task.reduce.ShuffleSchedulerImpl$Penalty" />
       <Bug pattern="EQ_COMPARETO_USE_OBJECT_EQUALS" />
      </Match>
 

+ 10 - 54
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcher.java

@@ -18,7 +18,6 @@
 package org.apache.hadoop.mapreduce.task.reduce;
 
 import java.io.IOException;
-import java.net.URI;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -37,11 +36,9 @@ class EventFetcher<K,V> extends Thread {
   private final TaskUmbilicalProtocol umbilical;
   private final ShuffleScheduler<K,V> scheduler;
   private int fromEventIdx = 0;
-  private int maxEventsToFetch;
-  private ExceptionReporter exceptionReporter = null;
+  private final int maxEventsToFetch;
+  private final ExceptionReporter exceptionReporter;
   
-  private int maxMapRuntime = 0;
-
   private volatile boolean stopped = false;
   
   public EventFetcher(TaskAttemptID reduce,
@@ -113,7 +110,8 @@ class EventFetcher<K,V> extends Thread {
    * from a given event ID.
    * @throws IOException
    */  
-  protected int getMapCompletionEvents() throws IOException {
+  protected int getMapCompletionEvents()
+      throws IOException, InterruptedException {
     
     int numNewMaps = 0;
     TaskCompletionEvent events[] = null;
@@ -129,14 +127,7 @@ class EventFetcher<K,V> extends Thread {
       LOG.debug("Got " + events.length + " map completion events from " +
                fromEventIdx);
 
-      // Check if the reset is required.
-      // Since there is no ordering of the task completion events at the
-      // reducer, the only option to sync with the new jobtracker is to reset
-      // the events index
-      if (update.shouldReset()) {
-        fromEventIdx = 0;
-        scheduler.resetKnownMaps();
-      }
+      assert !update.shouldReset() : "Unexpected legacy state";
 
       // Update the last seen event ID
       fromEventIdx += events.length;
@@ -148,49 +139,14 @@ class EventFetcher<K,V> extends Thread {
       // 3. Remove TIPFAILED maps from neededOutputs since we don't need their
       //    outputs at all.
       for (TaskCompletionEvent event : events) {
-        switch (event.getTaskStatus()) {
-        case SUCCEEDED:
-          URI u = getBaseURI(event.getTaskTrackerHttp());
-          scheduler.addKnownMapOutput(u.getHost() + ":" + u.getPort(),
-              u.toString(),
-              event.getTaskAttemptId());
-          numNewMaps ++;
-          int duration = event.getTaskRunTime();
-          if (duration > maxMapRuntime) {
-            maxMapRuntime = duration;
-            scheduler.informMaxMapRunTime(maxMapRuntime);
-          }
-          break;
-        case FAILED:
-        case KILLED:
-        case OBSOLETE:
-          scheduler.obsoleteMapOutput(event.getTaskAttemptId());
-          LOG.info("Ignoring obsolete output of " + event.getTaskStatus() + 
-              " map-task: '" + event.getTaskAttemptId() + "'");
-          break;
-        case TIPFAILED:
-          scheduler.tipFailed(event.getTaskAttemptId().getTaskID());
-          LOG.info("Ignoring output of failed map TIP: '" +  
-              event.getTaskAttemptId() + "'");
-          break;
+        scheduler.resolve(event);
+        if (TaskCompletionEvent.Status.SUCCEEDED == event.getTaskStatus()) {
+          ++numNewMaps;
         }
       }
     } while (events.length == maxEventsToFetch);
 
     return numNewMaps;
   }
-  
-  private URI getBaseURI(String url) {
-    StringBuffer baseUrl = new StringBuffer(url);
-    if (!url.endsWith("/")) {
-      baseUrl.append("/");
-    }
-    baseUrl.append("mapOutput?job=");
-    baseUrl.append(reduce.getJobID());
-    baseUrl.append("&reduce=");
-    baseUrl.append(reduce.getTaskID().getId());
-    baseUrl.append("&map=");
-    URI u = URI.create(baseUrl.toString());
-    return u;
-  }
-}
+
+}

+ 2 - 2
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java

@@ -72,7 +72,7 @@ class Fetcher<K,V> extends Thread {
   private final Counters.Counter wrongMapErrs;
   private final Counters.Counter wrongReduceErrs;
   private final MergeManager<K,V> merger;
-  private final ShuffleScheduler<K,V> scheduler;
+  private final ShuffleSchedulerImpl<K,V> scheduler;
   private final ShuffleClientMetrics metrics;
   private final ExceptionReporter exceptionReporter;
   private final int id;
@@ -90,7 +90,7 @@ class Fetcher<K,V> extends Thread {
   private static SSLFactory sslFactory;
 
   public Fetcher(JobConf job, TaskAttemptID reduceId, 
-                 ShuffleScheduler<K,V> scheduler, MergeManager<K,V> merger,
+                 ShuffleSchedulerImpl<K,V> scheduler, MergeManager<K,V> merger,
                  Reporter reporter, ShuffleClientMetrics metrics,
                  ExceptionReporter exceptionReporter, SecretKey shuffleKey) {
     this.reporter = reporter;

+ 3 - 3
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Shuffle.java

@@ -49,7 +49,7 @@ public class Shuffle<K, V> implements ShuffleConsumerPlugin<K, V>, ExceptionRepo
   private ShuffleClientMetrics metrics;
   private TaskUmbilicalProtocol umbilical;
   
-  private ShuffleScheduler<K,V> scheduler;
+  private ShuffleSchedulerImpl<K,V> scheduler;
   private MergeManager<K, V> merger;
   private Throwable throwable = null;
   private String throwingThreadName = null;
@@ -70,8 +70,8 @@ public class Shuffle<K, V> implements ShuffleConsumerPlugin<K, V>, ExceptionRepo
     this.taskStatus = context.getStatus();
     this.reduceTask = context.getReduceTask();
     
-    scheduler = new ShuffleScheduler<K,V>(jobConf, taskStatus, this,
-        copyPhase, context.getShuffledMapsCounter(),
+    scheduler = new ShuffleSchedulerImpl<K, V>(jobConf, taskStatus, reduceId,
+        this, copyPhase, context.getShuffledMapsCounter(),
         context.getReduceShuffleBytes(), context.getFailedShuffleCounter());
     merger = createMergeManager(context);
   }

+ 12 - 414
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/ShuffleScheduler.java

@@ -18,432 +18,30 @@
 package org.apache.hadoop.mapreduce.task.reduce;
 
 import java.io.IOException;
-import java.text.DecimalFormat;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
-import java.util.concurrent.DelayQueue;
-import java.util.concurrent.Delayed;
-import java.util.concurrent.TimeUnit;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.mapred.Counters;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.TaskStatus;
-import org.apache.hadoop.mapreduce.MRJobConfig;
-import org.apache.hadoop.mapreduce.TaskAttemptID;
-import org.apache.hadoop.mapreduce.TaskID;
-import org.apache.hadoop.mapreduce.task.reduce.MapHost.State;
-import org.apache.hadoop.util.Progress;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.mapred.TaskCompletionEvent;
 
-class ShuffleScheduler<K,V> {
-  static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
-    protected Long initialValue() {
-      return 0L;
-    }
-  };
+@InterfaceAudience.Private
+@InterfaceStability.Unstable
+public interface ShuffleScheduler<K,V> {
 
-  private static final Log LOG = LogFactory.getLog(ShuffleScheduler.class);
-  private static final int MAX_MAPS_AT_ONCE = 20;
-  private static final long INITIAL_PENALTY = 10000;
-  private static final float PENALTY_GROWTH_RATE = 1.3f;
-  private final static int REPORT_FAILURE_LIMIT = 10;
-
-  private final boolean[] finishedMaps;
-  private final int totalMaps;
-  private int remainingMaps;
-  private Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
-  private Set<MapHost> pendingHosts = new HashSet<MapHost>();
-  private Set<TaskAttemptID> obsoleteMaps = new HashSet<TaskAttemptID>();
-  
-  private final Random random = new Random(System.currentTimeMillis());
-  private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
-  private final Referee referee = new Referee();
-  private final Map<TaskAttemptID,IntWritable> failureCounts =
-    new HashMap<TaskAttemptID,IntWritable>();
-  private final Map<String,IntWritable> hostFailures = 
-    new HashMap<String,IntWritable>();
-  private final TaskStatus status;
-  private final ExceptionReporter reporter;
-  private final int abortFailureLimit;
-  private final Progress progress;
-  private final Counters.Counter shuffledMapsCounter;
-  private final Counters.Counter reduceShuffleBytes;
-  private final Counters.Counter failedShuffleCounter;
-  
-  private final long startTime;
-  private long lastProgressTime;
-  
-  private int maxMapRuntime = 0;
-  private int maxFailedUniqueFetches = 5;
-  private int maxFetchFailuresBeforeReporting;
-  
-  private long totalBytesShuffledTillNow = 0;
-  private DecimalFormat  mbpsFormat = new DecimalFormat("0.00");
-
-  private boolean reportReadErrorImmediately = true;
-  private long maxDelay = MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_RETRY_DELAY;
-  
-  public ShuffleScheduler(JobConf job, TaskStatus status,
-                          ExceptionReporter reporter,
-                          Progress progress,
-                          Counters.Counter shuffledMapsCounter,
-                          Counters.Counter reduceShuffleBytes,
-                          Counters.Counter failedShuffleCounter) {
-    totalMaps = job.getNumMapTasks();
-    abortFailureLimit = Math.max(30, totalMaps / 10);
-    remainingMaps = totalMaps;
-    finishedMaps = new boolean[remainingMaps];
-    this.reporter = reporter;
-    this.status = status;
-    this.progress = progress;
-    this.shuffledMapsCounter = shuffledMapsCounter;
-    this.reduceShuffleBytes = reduceShuffleBytes;
-    this.failedShuffleCounter = failedShuffleCounter;
-    this.startTime = System.currentTimeMillis();
-    lastProgressTime = startTime;
-    referee.start();
-    this.maxFailedUniqueFetches = Math.min(totalMaps,
-        this.maxFailedUniqueFetches);
-    this.maxFetchFailuresBeforeReporting = job.getInt(
-        MRJobConfig.SHUFFLE_FETCH_FAILURES, REPORT_FAILURE_LIMIT);
-    this.reportReadErrorImmediately = job.getBoolean(
-        MRJobConfig.SHUFFLE_NOTIFY_READERROR, true);
-    
-    this.maxDelay = job.getLong(MRJobConfig.MAX_SHUFFLE_FETCH_RETRY_DELAY, 
-        MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_RETRY_DELAY);
-  }
-
-  public synchronized void copySucceeded(TaskAttemptID mapId, 
-                                         MapHost host,
-                                         long bytes,
-                                         long millis,
-                                         MapOutput<K,V> output
-                                         ) throws IOException {
-    failureCounts.remove(mapId);
-    hostFailures.remove(host.getHostName());
-    int mapIndex = mapId.getTaskID().getId();
-    
-    if (!finishedMaps[mapIndex]) {
-      output.commit();
-      finishedMaps[mapIndex] = true;
-      shuffledMapsCounter.increment(1);
-      if (--remainingMaps == 0) {
-        notifyAll();
-      }
-
-      // update the status
-      totalBytesShuffledTillNow += bytes;
-      updateStatus();
-      reduceShuffleBytes.increment(bytes);
-      lastProgressTime = System.currentTimeMillis();
-      LOG.debug("map " + mapId + " done " + status.getStateString());
-    }
-  }
-  
-  private void updateStatus() {
-    float mbs = (float) totalBytesShuffledTillNow / (1024 * 1024);
-    int mapsDone = totalMaps - remainingMaps;
-    long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;
-
-    float transferRate = mbs / secsSinceStart;
-    progress.set((float) mapsDone / totalMaps);
-    String statusString = mapsDone + " / " + totalMaps + " copied.";
-    status.setStateString(statusString);
-
-    progress.setStatus("copy(" + mapsDone + " of " + totalMaps + " at "
-        + mbpsFormat.format(transferRate) + " MB/s)");
-  }
-
-  public synchronized void copyFailed(TaskAttemptID mapId, MapHost host,
-                                      boolean readError, boolean connectExcpt) {
-    host.penalize();
-    int failures = 1;
-    if (failureCounts.containsKey(mapId)) {
-      IntWritable x = failureCounts.get(mapId);
-      x.set(x.get() + 1);
-      failures = x.get();
-    } else {
-      failureCounts.put(mapId, new IntWritable(1));      
-    }
-    String hostname = host.getHostName();
-    if (hostFailures.containsKey(hostname)) {
-      IntWritable x = hostFailures.get(hostname);
-      x.set(x.get() + 1);
-    } else {
-      hostFailures.put(hostname, new IntWritable(1));
-    }
-    if (failures >= abortFailureLimit) {
-      try {
-        throw new IOException(failures + " failures downloading " + mapId);
-      } catch (IOException ie) {
-        reporter.reportException(ie);
-      }
-    }
-    
-    checkAndInformJobTracker(failures, mapId, readError, connectExcpt);
-
-    checkReducerHealth();
-    
-    long delay = (long) (INITIAL_PENALTY *
-        Math.pow(PENALTY_GROWTH_RATE, failures));
-    if (delay > maxDelay) {
-      delay = maxDelay;
-    }
-    
-    penalties.add(new Penalty(host, delay));
-    
-    failedShuffleCounter.increment(1);
-  }
-  
-  // Notify the JobTracker  
-  // after every read error, if 'reportReadErrorImmediately' is true or
-  // after every 'maxFetchFailuresBeforeReporting' failures
-  private void checkAndInformJobTracker(
-      int failures, TaskAttemptID mapId, boolean readError, 
-      boolean connectExcpt) {
-    if (connectExcpt || (reportReadErrorImmediately && readError)
-        || ((failures % maxFetchFailuresBeforeReporting) == 0)) {
-      LOG.info("Reporting fetch failure for " + mapId + " to jobtracker.");
-      status.addFetchFailedMap((org.apache.hadoop.mapred.TaskAttemptID) mapId);
-    }
-  }
-    
-  private void checkReducerHealth() {
-    final float MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT = 0.5f;
-    final float MIN_REQUIRED_PROGRESS_PERCENT = 0.5f;
-    final float MAX_ALLOWED_STALL_TIME_PERCENT = 0.5f;
-
-    long totalFailures = failedShuffleCounter.getValue();
-    int doneMaps = totalMaps - remainingMaps;
-    
-    boolean reducerHealthy =
-      (((float)totalFailures / (totalFailures + doneMaps))
-          < MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT);
-    
-    // check if the reducer has progressed enough
-    boolean reducerProgressedEnough =
-      (((float)doneMaps / totalMaps)
-          >= MIN_REQUIRED_PROGRESS_PERCENT);
-
-    // check if the reducer is stalled for a long time
-    // duration for which the reducer is stalled
-    int stallDuration =
-      (int)(System.currentTimeMillis() - lastProgressTime);
-    
-    // duration for which the reducer ran with progress
-    int shuffleProgressDuration =
-      (int)(lastProgressTime - startTime);
-
-    // min time the reducer should run without getting killed
-    int minShuffleRunDuration =
-      (shuffleProgressDuration > maxMapRuntime)
-      ? shuffleProgressDuration
-          : maxMapRuntime;
-    
-    boolean reducerStalled =
-      (((float)stallDuration / minShuffleRunDuration)
-          >= MAX_ALLOWED_STALL_TIME_PERCENT);
-
-    // kill if not healthy and has insufficient progress
-    if ((failureCounts.size() >= maxFailedUniqueFetches ||
-        failureCounts.size() == (totalMaps - doneMaps))
-        && !reducerHealthy
-        && (!reducerProgressedEnough || reducerStalled)) {
-      LOG.fatal("Shuffle failed with too many fetch failures " +
-      "and insufficient progress!");
-      String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
-      reporter.reportException(new IOException(errorMsg));
-    }
-
-  }
-  
-  public synchronized void tipFailed(TaskID taskId) {
-    if (!finishedMaps[taskId.getId()]) {
-      finishedMaps[taskId.getId()] = true;
-      if (--remainingMaps == 0) {
-        notifyAll();
-      }
-      updateStatus();
-    }
-  }
-  
-  public synchronized void addKnownMapOutput(String hostName, 
-                                             String hostUrl,
-                                             TaskAttemptID mapId) {
-    MapHost host = mapLocations.get(hostName);
-    if (host == null) {
-      host = new MapHost(hostName, hostUrl);
-      mapLocations.put(hostName, host);
-    }
-    host.addKnownMap(mapId);
-
-    // Mark the host as pending 
-    if (host.getState() == State.PENDING) {
-      pendingHosts.add(host);
-      notifyAll();
-    }
-  }
-  
-  public synchronized void obsoleteMapOutput(TaskAttemptID mapId) {
-    obsoleteMaps.add(mapId);
-  }
-  
-  public synchronized void putBackKnownMapOutput(MapHost host, 
-                                                 TaskAttemptID mapId) {
-    host.addKnownMap(mapId);
-  }
-
-  public synchronized MapHost getHost() throws InterruptedException {
-      while(pendingHosts.isEmpty()) {
-        wait();
-      }
-      
-      MapHost host = null;
-      Iterator<MapHost> iter = pendingHosts.iterator();
-      int numToPick = random.nextInt(pendingHosts.size());
-      for (int i=0; i <= numToPick; ++i) {
-        host = iter.next();
-      }
-      
-      pendingHosts.remove(host);     
-      host.markBusy();
-      
-      LOG.info("Assiging " + host + " with " + host.getNumKnownMapOutputs() + 
-               " to " + Thread.currentThread().getName());
-      shuffleStart.set(System.currentTimeMillis());
-      
-      return host;
-  }
-  
-  public synchronized List<TaskAttemptID> getMapsForHost(MapHost host) {
-    List<TaskAttemptID> list = host.getAndClearKnownMaps();
-    Iterator<TaskAttemptID> itr = list.iterator();
-    List<TaskAttemptID> result = new ArrayList<TaskAttemptID>();
-    int includedMaps = 0;
-    int totalSize = list.size();
-    // find the maps that we still need, up to the limit
-    while (itr.hasNext()) {
-      TaskAttemptID id = itr.next();
-      if (!obsoleteMaps.contains(id) && !finishedMaps[id.getTaskID().getId()]) {
-        result.add(id);
-        if (++includedMaps >= MAX_MAPS_AT_ONCE) {
-          break;
-        }
-      }
-    }
-    // put back the maps left after the limit
-    while (itr.hasNext()) {
-      TaskAttemptID id = itr.next();
-      if (!obsoleteMaps.contains(id) && !finishedMaps[id.getTaskID().getId()]) {
-        host.addKnownMap(id);
-      }
-    }
-    LOG.info("assigned " + includedMaps + " of " + totalSize + " to " +
-             host + " to " + Thread.currentThread().getName());
-    return result;
-  }
-
-  public synchronized void freeHost(MapHost host) {
-    if (host.getState() != State.PENALIZED) {
-      if (host.markAvailable() == State.PENDING) {
-        pendingHosts.add(host);
-        notifyAll();
-      }
-    }
-    LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " + 
-             (System.currentTimeMillis()-shuffleStart.get()) + "ms");
-  }
-    
-  public synchronized void resetKnownMaps() {
-    mapLocations.clear();
-    obsoleteMaps.clear();
-    pendingHosts.clear();
-  }
-  
   /**
    * Wait until the shuffle finishes or until the timeout.
    * @param millis maximum wait time
    * @return true if the shuffle is done
    * @throws InterruptedException
    */
-  public synchronized boolean waitUntilDone(int millis
-                                            ) throws InterruptedException {
-    if (remainingMaps > 0) {
-      wait(millis);
-      return remainingMaps == 0;
-    }
-    return true;
-  }
-  
-  /**
-   * A structure that records the penalty for a host.
-   */
-  private static class Penalty implements Delayed {
-    MapHost host;
-    private long endTime;
-    
-    Penalty(MapHost host, long delay) {
-      this.host = host;
-      this.endTime = System.currentTimeMillis() + delay;
-    }
-
-    public long getDelay(TimeUnit unit) {
-      long remainingTime = endTime - System.currentTimeMillis();
-      return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
-    }
+  public boolean waitUntilDone(int millis) throws InterruptedException;
 
-    public int compareTo(Delayed o) {
-      long other = ((Penalty) o).endTime;
-      return endTime == other ? 0 : (endTime < other ? -1 : 1);
-    }
-    
-  }
-  
   /**
-   * A thread that takes hosts off of the penalty list when the timer expires.
+   * Interpret a {@link TaskCompletionEvent} from the event stream.
+   * @param tce Intermediate output metadata
    */
-  private class Referee extends Thread {
-    public Referee() {
-      setName("ShufflePenaltyReferee");
-      setDaemon(true);
-    }
+  public void resolve(TaskCompletionEvent tce)
+    throws IOException, InterruptedException;
 
-    public void run() {
-      try {
-        while (true) {
-          // take the first host that has an expired penalty
-          MapHost host = penalties.take().host;
-          synchronized (ShuffleScheduler.this) {
-            if (host.markAvailable() == MapHost.State.PENDING) {
-              pendingHosts.add(host);
-              ShuffleScheduler.this.notifyAll();
-            }
-          }
-        }
-      } catch (InterruptedException ie) {
-        return;
-      } catch (Throwable t) {
-        reporter.reportException(t);
-      }
-    }
-  }
-  
-  public void close() throws InterruptedException {
-    referee.interrupt();
-    referee.join();
-  }
+  public void close() throws InterruptedException;
 
-  public synchronized void informMaxMapRunTime(int duration) {
-    if (duration > maxMapRuntime) {
-      maxMapRuntime = duration;
-    }
-  }
 }

+ 498 - 0
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/ShuffleSchedulerImpl.java

@@ -0,0 +1,498 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.mapreduce.task.reduce;
+
+import java.io.IOException;
+
+import java.net.URI;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.DelayQueue;
+import java.util.concurrent.Delayed;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.TaskCompletionEvent;
+import org.apache.hadoop.mapred.TaskStatus;
+import org.apache.hadoop.mapreduce.MRJobConfig;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+import org.apache.hadoop.mapreduce.TaskID;
+import org.apache.hadoop.mapreduce.task.reduce.MapHost.State;
+import org.apache.hadoop.util.Progress;
+
+@InterfaceAudience.Private
+@InterfaceStability.Unstable
+public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
+  static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
+    protected Long initialValue() {
+      return 0L;
+    }
+  };
+
+  private static final Log LOG = LogFactory.getLog(ShuffleSchedulerImpl.class);
+  private static final int MAX_MAPS_AT_ONCE = 20;
+  private static final long INITIAL_PENALTY = 10000;
+  private static final float PENALTY_GROWTH_RATE = 1.3f;
+  private final static int REPORT_FAILURE_LIMIT = 10;
+
+  private final boolean[] finishedMaps;
+
+  private final int totalMaps;
+  private int remainingMaps;
+  private Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
+  private Set<MapHost> pendingHosts = new HashSet<MapHost>();
+  private Set<TaskAttemptID> obsoleteMaps = new HashSet<TaskAttemptID>();
+
+  private final TaskAttemptID reduceId;
+  private final Random random = new Random();
+  private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
+  private final Referee referee = new Referee();
+  private final Map<TaskAttemptID,IntWritable> failureCounts =
+    new HashMap<TaskAttemptID,IntWritable>();
+  private final Map<String,IntWritable> hostFailures =
+    new HashMap<String,IntWritable>();
+  private final TaskStatus status;
+  private final ExceptionReporter reporter;
+  private final int abortFailureLimit;
+  private final Progress progress;
+  private final Counters.Counter shuffledMapsCounter;
+  private final Counters.Counter reduceShuffleBytes;
+  private final Counters.Counter failedShuffleCounter;
+
+  private final long startTime;
+  private long lastProgressTime;
+
+  private volatile int maxMapRuntime = 0;
+  private final int maxFailedUniqueFetches;
+  private final int maxFetchFailuresBeforeReporting;
+
+  private long totalBytesShuffledTillNow = 0;
+  private final DecimalFormat mbpsFormat = new DecimalFormat("0.00");
+
+  private final boolean reportReadErrorImmediately;
+  private long maxDelay = MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_RETRY_DELAY;
+
+  public ShuffleSchedulerImpl(JobConf job, TaskStatus status,
+                          TaskAttemptID reduceId,
+                          ExceptionReporter reporter,
+                          Progress progress,
+                          Counters.Counter shuffledMapsCounter,
+                          Counters.Counter reduceShuffleBytes,
+                          Counters.Counter failedShuffleCounter) {
+    totalMaps = job.getNumMapTasks();
+    abortFailureLimit = Math.max(30, totalMaps / 10);
+
+    remainingMaps = totalMaps;
+    finishedMaps = new boolean[remainingMaps];
+    this.reporter = reporter;
+    this.status = status;
+    this.reduceId = reduceId;
+    this.progress = progress;
+    this.shuffledMapsCounter = shuffledMapsCounter;
+    this.reduceShuffleBytes = reduceShuffleBytes;
+    this.failedShuffleCounter = failedShuffleCounter;
+    this.startTime = System.currentTimeMillis();
+    lastProgressTime = startTime;
+    referee.start();
+    this.maxFailedUniqueFetches = Math.min(totalMaps, 5);
+    this.maxFetchFailuresBeforeReporting = job.getInt(
+        MRJobConfig.SHUFFLE_FETCH_FAILURES, REPORT_FAILURE_LIMIT);
+    this.reportReadErrorImmediately = job.getBoolean(
+        MRJobConfig.SHUFFLE_NOTIFY_READERROR, true);
+
+    this.maxDelay = job.getLong(MRJobConfig.MAX_SHUFFLE_FETCH_RETRY_DELAY,
+        MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_RETRY_DELAY);
+  }
+
+  @Override
+  public void resolve(TaskCompletionEvent event) {
+    switch (event.getTaskStatus()) {
+    case SUCCEEDED:
+      URI u = getBaseURI(reduceId, event.getTaskTrackerHttp());
+      addKnownMapOutput(u.getHost() + ":" + u.getPort(),
+          u.toString(),
+          event.getTaskAttemptId());
+      maxMapRuntime = Math.max(maxMapRuntime, event.getTaskRunTime());
+      break;
+    case FAILED:
+    case KILLED:
+    case OBSOLETE:
+      obsoleteMapOutput(event.getTaskAttemptId());
+      LOG.info("Ignoring obsolete output of " + event.getTaskStatus() +
+          " map-task: '" + event.getTaskAttemptId() + "'");
+      break;
+    case TIPFAILED:
+      tipFailed(event.getTaskAttemptId().getTaskID());
+      LOG.info("Ignoring output of failed map TIP: '" +
+          event.getTaskAttemptId() + "'");
+      break;
+    }
+  }
+
+  static URI getBaseURI(TaskAttemptID reduceId, String url) {
+    StringBuffer baseUrl = new StringBuffer(url);
+    if (!url.endsWith("/")) {
+      baseUrl.append("/");
+    }
+    baseUrl.append("mapOutput?job=");
+    baseUrl.append(reduceId.getJobID());
+    baseUrl.append("&reduce=");
+    baseUrl.append(reduceId.getTaskID().getId());
+    baseUrl.append("&map=");
+    URI u = URI.create(baseUrl.toString());
+    return u;
+  }
+
+  public synchronized void copySucceeded(TaskAttemptID mapId,
+                                         MapHost host,
+                                         long bytes,
+                                         long millis,
+                                         MapOutput<K,V> output
+                                         ) throws IOException {
+    failureCounts.remove(mapId);
+    hostFailures.remove(host.getHostName());
+    int mapIndex = mapId.getTaskID().getId();
+
+    if (!finishedMaps[mapIndex]) {
+      output.commit();
+      finishedMaps[mapIndex] = true;
+      shuffledMapsCounter.increment(1);
+      if (--remainingMaps == 0) {
+        notifyAll();
+      }
+
+      // update the status
+      totalBytesShuffledTillNow += bytes;
+      updateStatus();
+      reduceShuffleBytes.increment(bytes);
+      lastProgressTime = System.currentTimeMillis();
+      LOG.debug("map " + mapId + " done " + status.getStateString());
+    }
+  }
+
+  private void updateStatus() {
+    float mbs = (float) totalBytesShuffledTillNow / (1024 * 1024);
+    int mapsDone = totalMaps - remainingMaps;
+    long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;
+
+    float transferRate = mbs / secsSinceStart;
+    progress.set((float) mapsDone / totalMaps);
+    String statusString = mapsDone + " / " + totalMaps + " copied.";
+    status.setStateString(statusString);
+
+    progress.setStatus("copy(" + mapsDone + " of " + totalMaps + " at "
+        + mbpsFormat.format(transferRate) + " MB/s)");
+  }
+
+  public synchronized void copyFailed(TaskAttemptID mapId, MapHost host,
+                                      boolean readError, boolean connectExcpt) {
+    host.penalize();
+    int failures = 1;
+    if (failureCounts.containsKey(mapId)) {
+      IntWritable x = failureCounts.get(mapId);
+      x.set(x.get() + 1);
+      failures = x.get();
+    } else {
+      failureCounts.put(mapId, new IntWritable(1));
+    }
+    String hostname = host.getHostName();
+    if (hostFailures.containsKey(hostname)) {
+      IntWritable x = hostFailures.get(hostname);
+      x.set(x.get() + 1);
+    } else {
+      hostFailures.put(hostname, new IntWritable(1));
+    }
+    if (failures >= abortFailureLimit) {
+      try {
+        throw new IOException(failures + " failures downloading " + mapId);
+      } catch (IOException ie) {
+        reporter.reportException(ie);
+      }
+    }
+
+    checkAndInformJobTracker(failures, mapId, readError, connectExcpt);
+
+    checkReducerHealth();
+
+    long delay = (long) (INITIAL_PENALTY *
+        Math.pow(PENALTY_GROWTH_RATE, failures));
+    if (delay > maxDelay) {
+      delay = maxDelay;
+    }
+
+    penalties.add(new Penalty(host, delay));
+
+    failedShuffleCounter.increment(1);
+  }
+
+  // Notify the JobTracker
+  // after every read error, if 'reportReadErrorImmediately' is true or
+  // after every 'maxFetchFailuresBeforeReporting' failures
+  private void checkAndInformJobTracker(
+      int failures, TaskAttemptID mapId, boolean readError,
+      boolean connectExcpt) {
+    if (connectExcpt || (reportReadErrorImmediately && readError)
+        || ((failures % maxFetchFailuresBeforeReporting) == 0)) {
+      LOG.info("Reporting fetch failure for " + mapId + " to jobtracker.");
+      status.addFetchFailedMap((org.apache.hadoop.mapred.TaskAttemptID) mapId);
+    }
+  }
+
+  private void checkReducerHealth() {
+    final float MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT = 0.5f;
+    final float MIN_REQUIRED_PROGRESS_PERCENT = 0.5f;
+    final float MAX_ALLOWED_STALL_TIME_PERCENT = 0.5f;
+
+    long totalFailures = failedShuffleCounter.getValue();
+    int doneMaps = totalMaps - remainingMaps;
+
+    boolean reducerHealthy =
+      (((float)totalFailures / (totalFailures + doneMaps))
+          < MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT);
+
+    // check if the reducer has progressed enough
+    boolean reducerProgressedEnough =
+      (((float)doneMaps / totalMaps)
+          >= MIN_REQUIRED_PROGRESS_PERCENT);
+
+    // check if the reducer is stalled for a long time
+    // duration for which the reducer is stalled
+    int stallDuration =
+      (int)(System.currentTimeMillis() - lastProgressTime);
+
+    // duration for which the reducer ran with progress
+    int shuffleProgressDuration =
+      (int)(lastProgressTime - startTime);
+
+    // min time the reducer should run without getting killed
+    int minShuffleRunDuration =
+      Math.max(shuffleProgressDuration, maxMapRuntime);
+
+    boolean reducerStalled =
+      (((float)stallDuration / minShuffleRunDuration)
+          >= MAX_ALLOWED_STALL_TIME_PERCENT);
+
+    // kill if not healthy and has insufficient progress
+    if ((failureCounts.size() >= maxFailedUniqueFetches ||
+        failureCounts.size() == (totalMaps - doneMaps))
+        && !reducerHealthy
+        && (!reducerProgressedEnough || reducerStalled)) {
+      LOG.fatal("Shuffle failed with too many fetch failures " +
+      "and insufficient progress!");
+      String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
+      reporter.reportException(new IOException(errorMsg));
+    }
+
+  }
+
+  public synchronized void tipFailed(TaskID taskId) {
+    if (!finishedMaps[taskId.getId()]) {
+      finishedMaps[taskId.getId()] = true;
+      if (--remainingMaps == 0) {
+        notifyAll();
+      }
+      updateStatus();
+    }
+  }
+
+  public synchronized void addKnownMapOutput(String hostName,
+                                             String hostUrl,
+                                             TaskAttemptID mapId) {
+    MapHost host = mapLocations.get(hostName);
+    if (host == null) {
+      host = new MapHost(hostName, hostUrl);
+      mapLocations.put(hostName, host);
+    }
+    host.addKnownMap(mapId);
+
+    // Mark the host as pending
+    if (host.getState() == State.PENDING) {
+      pendingHosts.add(host);
+      notifyAll();
+    }
+  }
+
+
+  public synchronized void obsoleteMapOutput(TaskAttemptID mapId) {
+    obsoleteMaps.add(mapId);
+  }
+
+  public synchronized void putBackKnownMapOutput(MapHost host,
+                                                 TaskAttemptID mapId) {
+    host.addKnownMap(mapId);
+  }
+
+
+  public synchronized MapHost getHost() throws InterruptedException {
+      while(pendingHosts.isEmpty()) {
+        wait();
+      }
+
+      MapHost host = null;
+      Iterator<MapHost> iter = pendingHosts.iterator();
+      int numToPick = random.nextInt(pendingHosts.size());
+      for (int i=0; i <= numToPick; ++i) {
+        host = iter.next();
+      }
+
+      pendingHosts.remove(host);
+      host.markBusy();
+
+      LOG.info("Assigning " + host + " with " + host.getNumKnownMapOutputs() +
+               " to " + Thread.currentThread().getName());
+      shuffleStart.set(System.currentTimeMillis());
+
+      return host;
+  }
+
+  public synchronized List<TaskAttemptID> getMapsForHost(MapHost host) {
+    List<TaskAttemptID> list = host.getAndClearKnownMaps();
+    Iterator<TaskAttemptID> itr = list.iterator();
+    List<TaskAttemptID> result = new ArrayList<TaskAttemptID>();
+    int includedMaps = 0;
+    int totalSize = list.size();
+    // find the maps that we still need, up to the limit
+    while (itr.hasNext()) {
+      TaskAttemptID id = itr.next();
+      if (!obsoleteMaps.contains(id) && !finishedMaps[id.getTaskID().getId()]) {
+        result.add(id);
+        if (++includedMaps >= MAX_MAPS_AT_ONCE) {
+          break;
+        }
+      }
+    }
+    // put back the maps left after the limit
+    while (itr.hasNext()) {
+      TaskAttemptID id = itr.next();
+      if (!obsoleteMaps.contains(id) && !finishedMaps[id.getTaskID().getId()]) {
+        host.addKnownMap(id);
+      }
+    }
+    LOG.info("assigned " + includedMaps + " of " + totalSize + " to " +
+             host + " to " + Thread.currentThread().getName());
+    return result;
+  }
+
+  public synchronized void freeHost(MapHost host) {
+    if (host.getState() != State.PENALIZED) {
+      if (host.markAvailable() == State.PENDING) {
+        pendingHosts.add(host);
+        notifyAll();
+      }
+    }
+    LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " +
+             (System.currentTimeMillis()-shuffleStart.get()) + "ms");
+  }
+
+  public synchronized void resetKnownMaps() {
+    mapLocations.clear();
+    obsoleteMaps.clear();
+    pendingHosts.clear();
+  }
+
+  /**
+   * Wait until the shuffle finishes or until the timeout.
+   * @param millis maximum wait time
+   * @return true if the shuffle is done
+   * @throws InterruptedException
+   */
+  @Override
+  public synchronized boolean waitUntilDone(int millis
+                                            ) throws InterruptedException {
+    if (remainingMaps > 0) {
+      wait(millis);
+      return remainingMaps == 0;
+    }
+    return true;
+  }
+
+  /**
+   * A structure that records the penalty for a host.
+   */
+  private static class Penalty implements Delayed {
+    MapHost host;
+    private long endTime;
+
+    Penalty(MapHost host, long delay) {
+      this.host = host;
+      this.endTime = System.currentTimeMillis() + delay;
+    }
+
+    @Override
+    public long getDelay(TimeUnit unit) {
+      long remainingTime = endTime - System.currentTimeMillis();
+      return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
+    }
+
+    @Override
+    public int compareTo(Delayed o) {
+      long other = ((Penalty) o).endTime;
+      return endTime == other ? 0 : (endTime < other ? -1 : 1);
+    }
+
+  }
+
+  /**
+   * A thread that takes hosts off of the penalty list when the timer expires.
+   */
+  private class Referee extends Thread {
+    public Referee() {
+      setName("ShufflePenaltyReferee");
+      setDaemon(true);
+    }
+
+    public void run() {
+      try {
+        while (true) {
+          // take the first host that has an expired penalty
+          MapHost host = penalties.take().host;
+          synchronized (ShuffleSchedulerImpl.this) {
+            if (host.markAvailable() == MapHost.State.PENDING) {
+              pendingHosts.add(host);
+              ShuffleSchedulerImpl.this.notifyAll();
+            }
+          }
+        }
+      } catch (InterruptedException ie) {
+        return;
+      } catch (Throwable t) {
+        reporter.reportException(t);
+      }
+    }
+  }
+
+  @Override
+  public void close() throws InterruptedException {
+    referee.interrupt();
+    referee.join();
+  }
+
+}

+ 8 - 5
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestEventFetcher.java

@@ -43,7 +43,8 @@ import org.mockito.InOrder;
 public class TestEventFetcher {
 
   @Test
-  public void testConsecutiveFetch() throws IOException {
+  public void testConsecutiveFetch()
+      throws IOException, InterruptedException {
     final int MAX_EVENTS_TO_FETCH = 100;
     TaskAttemptID tid = new TaskAttemptID("12345", 1, TaskType.REDUCE, 1, 1);
 
@@ -63,7 +64,8 @@ public class TestEventFetcher {
       .thenReturn(getMockedCompletionEventsUpdate(MAX_EVENTS_TO_FETCH*2, 3));
 
     @SuppressWarnings("unchecked")
-    ShuffleScheduler<String,String> scheduler = mock(ShuffleScheduler.class);
+    ShuffleScheduler<String,String> scheduler =
+      mock(ShuffleScheduler.class);
     ExceptionReporter reporter = mock(ExceptionReporter.class);
 
     EventFetcherForTest<String,String> ef =
@@ -79,8 +81,8 @@ public class TestEventFetcher {
         eq(MAX_EVENTS_TO_FETCH), eq(MAX_EVENTS_TO_FETCH), eq(tid));
     inOrder.verify(umbilical).getMapCompletionEvents(any(JobID.class),
         eq(MAX_EVENTS_TO_FETCH*2), eq(MAX_EVENTS_TO_FETCH), eq(tid));
-    verify(scheduler, times(MAX_EVENTS_TO_FETCH*2 + 3)).addKnownMapOutput(
-        anyString(), anyString(), any(TaskAttemptID.class));
+    verify(scheduler, times(MAX_EVENTS_TO_FETCH*2 + 3)).resolve(
+        any(TaskCompletionEvent.class));
   }
 
   private MapTaskCompletionEventsUpdate getMockedCompletionEventsUpdate(
@@ -108,7 +110,8 @@ public class TestEventFetcher {
     }
 
     @Override
-    public int getMapCompletionEvents() throws IOException {
+    public int getMapCompletionEvents()
+        throws IOException, InterruptedException {
       return super.getMapCompletionEvents();
     }
 

+ 9 - 9
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestFetcher.java

@@ -56,9 +56,10 @@ public class TestFetcher {
     private HttpURLConnection connection;
 
     public FakeFetcher(JobConf job, TaskAttemptID reduceId,
-        ShuffleScheduler<K,V> scheduler, MergeManagerImpl<K,V> merger, Reporter reporter,
-        ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter,
-        SecretKey jobTokenSecret, HttpURLConnection connection) {
+        ShuffleSchedulerImpl<K,V> scheduler, MergeManagerImpl<K,V> merger,
+        Reporter reporter, ShuffleClientMetrics metrics,
+        ExceptionReporter exceptionReporter, SecretKey jobTokenSecret,
+        HttpURLConnection connection) {
       super(job, reduceId, scheduler, merger, reporter, metrics, exceptionReporter,
           jobTokenSecret);
       this.connection = connection;
@@ -79,7 +80,7 @@ public class TestFetcher {
     LOG.info("testCopyFromHostConnectionTimeout");
     JobConf job = new JobConf();
     TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1");
-    ShuffleScheduler<Text, Text> ss = mock(ShuffleScheduler.class);
+    ShuffleSchedulerImpl<Text, Text> ss = mock(ShuffleSchedulerImpl.class);
     MergeManagerImpl<Text, Text> mm = mock(MergeManagerImpl.class);
     Reporter r = mock(Reporter.class);
     ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class);
@@ -127,7 +128,7 @@ public class TestFetcher {
     LOG.info("testCopyFromHostBogusHeader");
     JobConf job = new JobConf();
     TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1");
-    ShuffleScheduler<Text, Text> ss = mock(ShuffleScheduler.class);
+    ShuffleSchedulerImpl<Text, Text> ss = mock(ShuffleSchedulerImpl.class);
     MergeManagerImpl<Text, Text> mm = mock(MergeManagerImpl.class);
     Reporter r = mock(Reporter.class);
     ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class);
@@ -182,7 +183,7 @@ public class TestFetcher {
     LOG.info("testCopyFromHostWait");
     JobConf job = new JobConf();
     TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1");
-    ShuffleScheduler<Text, Text> ss = mock(ShuffleScheduler.class);
+    ShuffleSchedulerImpl<Text, Text> ss = mock(ShuffleSchedulerImpl.class);
     MergeManagerImpl<Text, Text> mm = mock(MergeManagerImpl.class);
     Reporter r = mock(Reporter.class);
     ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class);
@@ -240,7 +241,7 @@ public class TestFetcher {
     LOG.info("testCopyFromHostWaitExtraBytes");
     JobConf job = new JobConf();
     TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1");
-    ShuffleScheduler<Text, Text> ss = mock(ShuffleScheduler.class);
+    ShuffleSchedulerImpl<Text, Text> ss = mock(ShuffleSchedulerImpl.class);
     MergeManagerImpl<Text, Text> mm = mock(MergeManagerImpl.class);
     InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class);
 
@@ -256,7 +257,6 @@ public class TestFetcher {
     
     Fetcher<Text,Text> underTest = new FakeFetcher<Text,Text>(job, id, ss, mm,
         r, metrics, except, key, connection);
-    
 
     MapHost host = new MapHost("localhost", "http://localhost:8080/");
     
@@ -315,7 +315,7 @@ public class TestFetcher {
     LOG.info("testCopyFromHostCompressFailure");
     JobConf job = new JobConf();
     TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1");
-    ShuffleScheduler<Text, Text> ss = mock(ShuffleScheduler.class);
+    ShuffleSchedulerImpl<Text, Text> ss = mock(ShuffleSchedulerImpl.class);
     MergeManagerImpl<Text, Text> mm = mock(MergeManagerImpl.class);
     InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class);
     Reporter r = mock(Reporter.class);

+ 4 - 2
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestShuffleScheduler.java

@@ -47,8 +47,10 @@ public class TestShuffleScheduler {
     };
     Progress progress = new Progress();
 
-    ShuffleScheduler scheduler = new ShuffleScheduler(job, status, null,
-        progress, null, null, null);
+    TaskAttemptID reduceId = new TaskAttemptID("314159", 0, TaskType.REDUCE,
+        0, 0);
+    ShuffleSchedulerImpl scheduler = new ShuffleSchedulerImpl(job, status,
+        reduceId, null, progress, null, null, null);
 
     JobID jobId = new JobID();
     TaskID taskId1 = new TaskID(jobId, TaskType.REDUCE, 1);