Selaa lähdekoodia

MAPREDUCE-5891. Improved shuffle error handling across NM restarts. Contributed by Junping Du
(cherry picked from commit 2c3da25fd718b3a9c1ed67f05b577975ae613f4e)

Jason Lowe 10 vuotta sitten
vanhempi
commit
44c7113133

+ 3 - 0
hadoop-mapreduce-project/CHANGES.txt

@@ -34,6 +34,9 @@ Release 2.6.0 - UNRELEASED
     MAPREDUCE-5130. Add missing job config options to mapred-default.xml
     (Ray Chiang via Sandy Ryza)
 
+    MAPREDUCE-5891. Improved shuffle error handling across NM restarts
+    (Junping Du via jlowe)
+
   OPTIMIZATIONS
 
   BUG FIXES

+ 8 - 0
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/MRJobConfig.java

@@ -293,6 +293,14 @@ public interface MRJobConfig {
   
   public static final String MAX_FETCH_FAILURES_NOTIFICATIONS = "mapreduce.reduce.shuffle.max-fetch-failures-notifications";
   public static final int DEFAULT_MAX_FETCH_FAILURES_NOTIFICATIONS = 3;
+  
+  public static final String SHUFFLE_FETCH_RETRY_INTERVAL_MS = "mapreduce.reduce.shuffle.fetch.retry.interval-ms";
+  /** Default interval that fetcher retry to fetch during NM restart.*/
+  public final static int DEFAULT_SHUFFLE_FETCH_RETRY_INTERVAL_MS = 1000;
+  
+  public static final String SHUFFLE_FETCH_RETRY_TIMEOUT_MS = "mapreduce.reduce.shuffle.fetch.retry.timeout-ms";
+  
+  public static final String SHUFFLE_FETCH_RETRY_ENABLED = "mapreduce.reduce.shuffle.fetch.retry.enabled";
 
   public static final String SHUFFLE_NOTIFY_READERROR = "mapreduce.reduce.shuffle.notify.readerror";
   

+ 176 - 55
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java

@@ -27,6 +27,7 @@ import java.net.URL;
 import java.net.URLConnection;
 import java.security.GeneralSecurityException;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -46,6 +47,8 @@ import org.apache.hadoop.mapreduce.TaskAttemptID;
 import org.apache.hadoop.mapreduce.security.SecureShuffleUtils;
 import org.apache.hadoop.mapreduce.CryptoUtils;
 import org.apache.hadoop.security.ssl.SSLFactory;
+import org.apache.hadoop.util.Time;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
 
 import com.google.common.annotations.VisibleForTesting;
 
@@ -85,10 +88,18 @@ class Fetcher<K,V> extends Thread {
   private final int connectionTimeout;
   private final int readTimeout;
   
+  private final int fetchRetryTimeout;
+  private final int fetchRetryInterval;
+  
+  private final boolean fetchRetryEnabled;
+  
   private final SecretKey shuffleSecretKey;
 
   protected HttpURLConnection connection;
   private volatile boolean stopped = false;
+  
+  // Initiative value is 0, which means it hasn't retried yet.
+  private long retryStartTime = 0;
 
   private static boolean sslShuffle;
   private static SSLFactory sslFactory;
@@ -135,6 +146,19 @@ class Fetcher<K,V> extends Thread {
     this.readTimeout = 
       job.getInt(MRJobConfig.SHUFFLE_READ_TIMEOUT, DEFAULT_READ_TIMEOUT);
     
+    this.fetchRetryInterval = job.getInt(MRJobConfig.SHUFFLE_FETCH_RETRY_INTERVAL_MS,
+        MRJobConfig.DEFAULT_SHUFFLE_FETCH_RETRY_INTERVAL_MS);
+    
+    this.fetchRetryTimeout = job.getInt(MRJobConfig.SHUFFLE_FETCH_RETRY_TIMEOUT_MS, 
+        DEFAULT_STALLED_COPY_TIMEOUT);
+    
+    boolean shuffleFetchEnabledDefault = job.getBoolean(
+        YarnConfiguration.NM_RECOVERY_ENABLED, 
+        YarnConfiguration.DEFAULT_NM_RECOVERY_ENABLED);
+    this.fetchRetryEnabled = job.getBoolean(
+        MRJobConfig.SHUFFLE_FETCH_RETRY_ENABLED, 
+        shuffleFetchEnabledDefault);
+    
     setName("fetcher#" + id);
     setDaemon(true);
 
@@ -242,6 +266,8 @@ class Fetcher<K,V> extends Thread {
    */
   @VisibleForTesting
   protected void copyFromHost(MapHost host) throws IOException {
+    // reset retryStartTime for a new host
+    retryStartTime = 0;
     // Get completed maps on 'host'
     List<TaskAttemptID> maps = scheduler.getMapsForHost(host);
     
@@ -261,60 +287,14 @@ class Fetcher<K,V> extends Thread {
     
     // Construct the url and connect
     DataInputStream input = null;
+    URL url = getMapOutputURL(host, maps);
     try {
-      URL url = getMapOutputURL(host, maps);
-      openConnection(url);
-      if (stopped) {
-        abortConnect(host, remaining);
-        return;
-      }
+      setupConnectionsWithRetry(host, remaining, url);
       
-      // generate hash of the url
-      String msgToEncode = SecureShuffleUtils.buildMsgFrom(url);
-      String encHash = SecureShuffleUtils.hashFromString(msgToEncode,
-          shuffleSecretKey);
-      
-      // put url hash into http header
-      connection.addRequestProperty(
-          SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash);
-      // set the read timeout
-      connection.setReadTimeout(readTimeout);
-      // put shuffle version into http header
-      connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
-          ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
-      connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
-          ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
-      connect(connection, connectionTimeout);
-      // verify that the thread wasn't stopped during calls to connect
       if (stopped) {
         abortConnect(host, remaining);
         return;
       }
-      input = new DataInputStream(connection.getInputStream());
-
-      // Validate response code
-      int rc = connection.getResponseCode();
-      if (rc != HttpURLConnection.HTTP_OK) {
-        throw new IOException(
-            "Got invalid response code " + rc + " from " + url +
-            ": " + connection.getResponseMessage());
-      }
-      // get the shuffle version
-      if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
-          connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
-          || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
-              connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) {
-        throw new IOException("Incompatible shuffle response version");
-      }
-      // get the replyHash which is HMac of the encHash we sent to the server
-      String replyHash = connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH);
-      if(replyHash==null) {
-        throw new IOException("security validation of TT Map output failed");
-      }
-      LOG.debug("url="+msgToEncode+";encHash="+encHash+";replyHash="+replyHash);
-      // verify that replyHash is HMac of encHash
-      SecureShuffleUtils.verifyReply(replyHash, encHash, shuffleSecretKey);
-      LOG.info("for url="+msgToEncode+" sent hash and received reply");
     } catch (IOException ie) {
       boolean connectExcpt = ie instanceof ConnectException;
       ioErrs.increment(1);
@@ -336,6 +316,8 @@ class Fetcher<K,V> extends Thread {
       return;
     }
     
+    input = new DataInputStream(connection.getInputStream());
+    
     try {
       // Loop through available map-outputs and fetch them
       // On any error, faildTasks is not null and we exit
@@ -343,7 +325,23 @@ class Fetcher<K,V> extends Thread {
       // yet_to_be_fetched list and marking the failed tasks.
       TaskAttemptID[] failedTasks = null;
       while (!remaining.isEmpty() && failedTasks == null) {
-        failedTasks = copyMapOutput(host, input, remaining);
+        try {
+          failedTasks = copyMapOutput(host, input, remaining, fetchRetryEnabled);
+        } catch (IOException e) {
+          //
+          // Setup connection again if disconnected by NM
+          connection.disconnect();
+          // Get map output from remaining tasks only.
+          url = getMapOutputURL(host, remaining);
+          
+          // Connect with retry as expecting host's recovery take sometime.
+          setupConnectionsWithRetry(host, remaining, url);
+          if (stopped) {
+            abortConnect(host, remaining);
+            return;
+          }
+          input = new DataInputStream(connection.getInputStream());
+        }
       }
       
       if(failedTasks != null && failedTasks.length > 0) {
@@ -371,19 +369,111 @@ class Fetcher<K,V> extends Thread {
       }
     }
   }
+
+  private void setupConnectionsWithRetry(MapHost host,
+      Set<TaskAttemptID> remaining, URL url) throws IOException {
+    openConnectionWithRetry(host, remaining, url);
+    if (stopped) {
+      return;
+    }
+      
+    // generate hash of the url
+    String msgToEncode = SecureShuffleUtils.buildMsgFrom(url);
+    String encHash = SecureShuffleUtils.hashFromString(msgToEncode,
+        shuffleSecretKey);
+    
+    setupShuffleConnection(encHash);
+    connect(connection, connectionTimeout);
+    // verify that the thread wasn't stopped during calls to connect
+    if (stopped) {
+      return;
+    }
+    
+    verifyConnection(url, msgToEncode, encHash);
+  }
+
+  private void openConnectionWithRetry(MapHost host,
+      Set<TaskAttemptID> remaining, URL url) throws IOException {
+    long startTime = Time.monotonicNow();
+    boolean shouldWait = true;
+    while (shouldWait) {
+      try {
+        openConnection(url);
+        shouldWait = false;
+      } catch (IOException e) {
+        if (!fetchRetryEnabled) {
+           // throw exception directly if fetch's retry is not enabled
+           throw e;
+        }
+        if ((Time.monotonicNow() - startTime) >= this.fetchRetryTimeout) {
+          LOG.warn("Failed to connect to host: " + url + "after " 
+              + fetchRetryTimeout + "milliseconds.");
+          throw e;
+        }
+        try {
+          Thread.sleep(this.fetchRetryInterval);
+        } catch (InterruptedException e1) {
+          if (stopped) {
+            return;
+          }
+        }
+      }
+    }
+  }
+
+  private void verifyConnection(URL url, String msgToEncode, String encHash)
+      throws IOException {
+    // Validate response code
+    int rc = connection.getResponseCode();
+    if (rc != HttpURLConnection.HTTP_OK) {
+      throw new IOException(
+          "Got invalid response code " + rc + " from " + url +
+          ": " + connection.getResponseMessage());
+    }
+    // get the shuffle version
+    if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
+        connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
+        || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
+            connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) {
+      throw new IOException("Incompatible shuffle response version");
+    }
+    // get the replyHash which is HMac of the encHash we sent to the server
+    String replyHash = connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH);
+    if(replyHash==null) {
+      throw new IOException("security validation of TT Map output failed");
+    }
+    LOG.debug("url="+msgToEncode+";encHash="+encHash+";replyHash="+replyHash);
+    // verify that replyHash is HMac of encHash
+    SecureShuffleUtils.verifyReply(replyHash, encHash, shuffleSecretKey);
+    LOG.info("for url="+msgToEncode+" sent hash and received reply");
+  }
+
+  private void setupShuffleConnection(String encHash) {
+    // put url hash into http header
+    connection.addRequestProperty(
+        SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash);
+    // set the read timeout
+    connection.setReadTimeout(readTimeout);
+    // put shuffle version into http header
+    connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+        ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+    connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+        ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+  }
   
   private static TaskAttemptID[] EMPTY_ATTEMPT_ID_ARRAY = new TaskAttemptID[0];
   
   private TaskAttemptID[] copyMapOutput(MapHost host,
                                 DataInputStream input,
-                                Set<TaskAttemptID> remaining) {
+                                Set<TaskAttemptID> remaining,
+                                boolean canRetry) throws IOException {
     MapOutput<K,V> mapOutput = null;
     TaskAttemptID mapId = null;
     long decompressedLength = -1;
     long compressedLength = -1;
     
     try {
-      long startTime = System.currentTimeMillis();
+      long startTime = Time.monotonicNow();
       int forReduce = -1;
       //Read the shuffle header
       try {
@@ -449,7 +539,10 @@ class Fetcher<K,V> extends Thread {
       }
       
       // Inform the shuffle scheduler
-      long endTime = System.currentTimeMillis();
+      long endTime = Time.monotonicNow();
+      // Reset retryStartTime as map task make progress if retried before.
+      retryStartTime = 0;
+      
       scheduler.copySucceeded(mapId, host, compressedLength, 
                               endTime - startTime, mapOutput);
       // Note successful shuffle
@@ -457,9 +550,14 @@ class Fetcher<K,V> extends Thread {
       metrics.successFetch();
       return null;
     } catch (IOException ioe) {
+      
+      if (canRetry) {
+        checkTimeoutOrRetry(host, ioe);
+      } 
+      
       ioErrs.increment(1);
       if (mapId == null || mapOutput == null) {
-        LOG.info("fetcher#" + id + " failed to read map header" + 
+        LOG.warn("fetcher#" + id + " failed to read map header" + 
                  mapId + " decomp: " + 
                  decompressedLength + ", " + compressedLength, ioe);
         if(mapId == null) {
@@ -468,7 +566,7 @@ class Fetcher<K,V> extends Thread {
           return new TaskAttemptID[] {mapId};
         }
       }
-      
+        
       LOG.warn("Failed to shuffle output of " + mapId + 
                " from " + host.getHostName(), ioe); 
 
@@ -479,6 +577,29 @@ class Fetcher<K,V> extends Thread {
     }
 
   }
+
+  /** check if hit timeout of retry, if not, throw an exception and start a 
+   *  new round of retry.*/
+  private void checkTimeoutOrRetry(MapHost host, IOException ioe)
+      throws IOException {
+    // First time to retry.
+    long currentTime = Time.monotonicNow();
+    if (retryStartTime == 0) {
+       retryStartTime = currentTime;
+    }
+  
+    // Retry is not timeout, let's do retry with throwing an exception.
+    if (currentTime - retryStartTime < this.fetchRetryTimeout) {
+      LOG.warn("Shuffle output from " + host.getHostName() +
+          " failed, retry it.");
+      throw ioe;
+    } else {
+      // timeout, prepare to be failed.
+      LOG.warn("Timeout for copying MapOutput with retry on host " + host 
+          + "after " + fetchRetryTimeout + "milliseconds.");
+      
+    }
+  }
   
   /**
    * Do some basic verification on the input received -- Being defensive
@@ -525,7 +646,7 @@ class Fetcher<K,V> extends Thread {
    * @return
    * @throws MalformedURLException
    */
-  private URL getMapOutputURL(MapHost host, List<TaskAttemptID> maps
+  private URL getMapOutputURL(MapHost host, Collection<TaskAttemptID> maps
                               )  throws MalformedURLException {
     // Get the base url
     StringBuffer url = new StringBuffer(host.getBaseUrl());

+ 9 - 8
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/ShuffleSchedulerImpl.java

@@ -48,6 +48,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptID;
 import org.apache.hadoop.mapreduce.TaskID;
 import org.apache.hadoop.mapreduce.task.reduce.MapHost.State;
 import org.apache.hadoop.util.Progress;
+import org.apache.hadoop.util.Time;
 
 @InterfaceAudience.Private
 @InterfaceStability.Unstable
@@ -121,7 +122,7 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
     this.shuffledMapsCounter = shuffledMapsCounter;
     this.reduceShuffleBytes = reduceShuffleBytes;
     this.failedShuffleCounter = failedShuffleCounter;
-    this.startTime = System.currentTimeMillis();
+    this.startTime = Time.monotonicNow();
     lastProgressTime = startTime;
     referee.start();
     this.maxFailedUniqueFetches = Math.min(totalMaps, 5);
@@ -198,7 +199,7 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
       totalBytesShuffledTillNow += bytes;
       updateStatus();
       reduceShuffleBytes.increment(bytes);
-      lastProgressTime = System.currentTimeMillis();
+      lastProgressTime = Time.monotonicNow();
       LOG.debug("map " + mapId + " done " + status.getStateString());
     }
   }
@@ -206,7 +207,7 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
   private void updateStatus() {
     float mbs = (float) totalBytesShuffledTillNow / (1024 * 1024);
     int mapsDone = totalMaps - remainingMaps;
-    long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;
+    long secsSinceStart = (Time.monotonicNow() - startTime) / 1000 + 1;
 
     float transferRate = mbs / secsSinceStart;
     progress.set((float) mapsDone / totalMaps);
@@ -307,7 +308,7 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
     // check if the reducer is stalled for a long time
     // duration for which the reducer is stalled
     int stallDuration =
-      (int)(System.currentTimeMillis() - lastProgressTime);
+      (int)(Time.monotonicNow() - lastProgressTime);
 
     // duration for which the reducer ran with progress
     int shuffleProgressDuration =
@@ -389,7 +390,7 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
 
       LOG.info("Assigning " + host + " with " + host.getNumKnownMapOutputs() +
                " to " + Thread.currentThread().getName());
-      shuffleStart.set(System.currentTimeMillis());
+      shuffleStart.set(Time.monotonicNow());
 
       return host;
   }
@@ -430,7 +431,7 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
       }
     }
     LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " +
-             (System.currentTimeMillis()-shuffleStart.get()) + "ms");
+             (Time.monotonicNow()-shuffleStart.get()) + "ms");
   }
 
   public synchronized void resetKnownMaps() {
@@ -464,12 +465,12 @@ public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
 
     Penalty(MapHost host, long delay) {
       this.host = host;
-      this.endTime = System.currentTimeMillis() + delay;
+      this.endTime = Time.monotonicNow() + delay;
     }
 
     @Override
     public long getDelay(TimeUnit unit) {
-      long remainingTime = endTime - System.currentTimeMillis();
+      long remainingTime = endTime - Time.monotonicNow();
       return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
     }
 

+ 21 - 0
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/resources/mapred-default.xml

@@ -339,6 +339,27 @@
   </description>
 </property>
 
+<property>
+  <name>mapreduce.reduce.shuffle.fetch.retry.enabled</name>
+  <value>${yarn.nodemanager.recovery.enabled}</value>
+  <description>Set to enable fetch retry during host restart.</description>
+</property>
+
+<property>
+  <name>mapreduce.reduce.shuffle.fetch.retry.interval-ms</name>
+  <value>1000</value>
+  <description>Time of interval that fetcher retry to fetch again when some 
+  non-fatal failure happens because of some events like NM restart.
+  </description>
+</property>
+
+<property>
+  <name>mapreduce.reduce.shuffle.fetch.retry.timeout-ms</name>
+  <value>30000</value>
+  <description>Timeout value for fetcher to retry to fetch again when some 
+  non-fatal failure happens because of some events like NM restart.</description>
+</property>
+
 <property>
   <name>mapreduce.reduce.shuffle.retry-delay.max.ms</name>
   <value>60000</value>

+ 95 - 1
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestFetcher.java

@@ -27,6 +27,7 @@ import java.net.HttpURLConnection;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.mapred.MapOutputFile;
+import org.apache.hadoop.mapreduce.MRJobConfig;
 import org.apache.hadoop.mapreduce.TaskID;
 
 import org.junit.After;
@@ -60,6 +61,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptID;
 import org.apache.hadoop.mapreduce.security.SecureShuffleUtils;
 import org.apache.hadoop.mapreduce.security.token.JobTokenSecretManager;
 import org.apache.hadoop.util.DiskChecker.DiskErrorException;
+import org.apache.hadoop.util.Time;
 import org.junit.Test;
 
 import org.mockito.invocation.InvocationOnMock;
@@ -71,6 +73,7 @@ import org.mockito.stubbing.Answer;
 public class TestFetcher {
   private static final Log LOG = LogFactory.getLog(TestFetcher.class);
   JobConf job = null;
+  JobConf jobWithRetry = null;
   TaskAttemptID id = null;
   ShuffleSchedulerImpl<Text, Text> ss = null;
   MergeManagerImpl<Text, Text> mm = null;
@@ -93,6 +96,9 @@ public class TestFetcher {
   public void setup() {
     LOG.info(">>>> " + name.getMethodName());
     job = new JobConf();
+    job.setBoolean(MRJobConfig.SHUFFLE_FETCH_RETRY_ENABLED, false);
+    jobWithRetry = new JobConf();
+    jobWithRetry.setBoolean(MRJobConfig.SHUFFLE_FETCH_RETRY_ENABLED, true);
     id = TaskAttemptID.forName("attempt_0_1_r_1_1");
     ss = mock(ShuffleSchedulerImpl.class);
     mm = mock(MergeManagerImpl.class);
@@ -228,6 +234,38 @@ public class TestFetcher {
     verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map1ID));
     verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map2ID));
   }
+  
+  @Test
+  public void testCopyFromHostIncompatibleShuffleVersionWithRetry()
+      throws Exception {
+    String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key);
+    
+    when(connection.getResponseCode()).thenReturn(200);
+    when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
+        .thenReturn("mapreduce").thenReturn("other").thenReturn("other");
+    when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))
+        .thenReturn("1.0.1").thenReturn("1.0.0").thenReturn("1.0.1");
+    when(connection.getHeaderField(
+        SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash);
+    ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]);
+    when(connection.getInputStream()).thenReturn(in);
+
+    for (int i = 0; i < 3; ++i) {
+      Fetcher<Text,Text> underTest = new FakeFetcher<Text,Text>(jobWithRetry, 
+          id, ss, mm, r, metrics, except, key, connection);
+      underTest.copyFromHost(host);
+    }
+    
+    verify(connection, times(3)).addRequestProperty(
+        SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash);
+    
+    verify(allErrs, times(3)).increment(1);
+    verify(ss, times(3)).copyFailed(map1ID, host, false, false);
+    verify(ss, times(3)).copyFailed(map2ID, host, false, false);
+    
+    verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map1ID));
+    verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map2ID));
+  }
 
   @Test
   public void testCopyFromHostWait() throws Exception {
@@ -301,6 +339,48 @@ public class TestFetcher {
           encHash);
     verify(ss, times(1)).copyFailed(map1ID, host, true, false);
   }
+  
+  @SuppressWarnings("unchecked")
+  @Test(timeout=10000) 
+  public void testCopyFromHostWithRetry() throws Exception {
+    InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class);
+    ss = mock(ShuffleSchedulerImpl.class);
+    Fetcher<Text,Text> underTest = new FakeFetcher<Text,Text>(jobWithRetry, 
+        id, ss, mm, r, metrics, except, key, connection, true);
+
+    String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key);
+    
+    when(connection.getResponseCode()).thenReturn(200);
+    when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH))
+        .thenReturn(replyHash);
+    ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1);
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+    header.write(new DataOutputStream(bout));
+    ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray());
+    when(connection.getInputStream()).thenReturn(in);
+    when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
+        .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+    when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))
+        .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+    when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt()))
+        .thenReturn(immo);
+    
+    final long retryTime = Time.monotonicNow();
+    doAnswer(new Answer<Void>() {
+      public Void answer(InvocationOnMock ignore) throws IOException {
+        // Emulate host down for 3 seconds.
+        if ((Time.monotonicNow() - retryTime) <= 3000) {
+          throw new java.lang.InternalError();
+        }
+        return null;
+      }
+    }).when(immo).shuffle(any(MapHost.class), any(InputStream.class), anyLong(), 
+        anyLong(), any(ShuffleClientMetrics.class), any(Reporter.class));
+
+    underTest.copyFromHost(host);
+    verify(ss, never()).copyFailed(any(TaskAttemptID.class),any(MapHost.class),
+                                   anyBoolean(), anyBoolean());
+  }
 
   @Test
   public void testCopyFromHostExtraBytes() throws Exception {
@@ -447,6 +527,9 @@ public class TestFetcher {
 
   public static class FakeFetcher<K,V> extends Fetcher<K,V> {
 
+    // If connection need to be reopen.
+    private boolean renewConnection = false;
+    
     public FakeFetcher(JobConf job, TaskAttemptID reduceId,
         ShuffleSchedulerImpl<K,V> scheduler, MergeManagerImpl<K,V> merger,
         Reporter reporter, ShuffleClientMetrics metrics,
@@ -456,6 +539,17 @@ public class TestFetcher {
           exceptionReporter, jobTokenSecret);
       this.connection = connection;
     }
+    
+    public FakeFetcher(JobConf job, TaskAttemptID reduceId,
+        ShuffleSchedulerImpl<K,V> scheduler, MergeManagerImpl<K,V> merger,
+        Reporter reporter, ShuffleClientMetrics metrics,
+        ExceptionReporter exceptionReporter, SecretKey jobTokenSecret,
+        HttpURLConnection connection, boolean renewConnection) {
+      super(job, reduceId, scheduler, merger, reporter, metrics,
+          exceptionReporter, jobTokenSecret);
+      this.connection = connection;
+      this.renewConnection = renewConnection;
+    }
 
     public FakeFetcher(JobConf job, TaskAttemptID reduceId,
         ShuffleSchedulerImpl<K,V> scheduler, MergeManagerImpl<K,V> merger,
@@ -469,7 +563,7 @@ public class TestFetcher {
 
     @Override
     protected void openConnection(URL url) throws IOException {
-      if (null == connection) {
+      if (null == connection || renewConnection) {
         super.openConnection(url);
       }
       // already 'opened' the mocked connection