Browse Source

MAPREDUCE-7503. Fix ByteBuf leaks in TestShuffleChannelHandler (#7500) Contributed by Istvan Toth.

Signed-off-by: Shilun Fan <slfan1989@apache.org>
Istvan Toth 1 month ago
parent
commit
cd8f18b71d

+ 35 - 2
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleChannelHandler.java

@@ -44,6 +44,7 @@ import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
 import io.netty.handler.ssl.SslHandler;
 import io.netty.handler.stream.ChunkedWriteHandler;
+import io.netty.util.ReferenceCounted;
 import io.netty.util.concurrent.GlobalEventExecutor;
 
 import java.io.ByteArrayOutputStream;
@@ -61,6 +62,7 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
 import javax.crypto.SecretKey;
@@ -115,6 +117,7 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
     final LinkedList<Object> unencryptedMessages = new LinkedList<>();
     final EmbeddedChannel shuffle = t.createShuffleHandlerSSL(unencryptedMessages);
     t.testGetAllAttemptsForReduce0NoKeepAlive(unencryptedMessages, shuffle);
+    drainChannel(shuffle);
   }
 
   @Test
@@ -192,8 +195,10 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
 
     assertEquals(getExpectedHttpResponse(HttpResponseStatus.BAD_REQUEST).toString(),
         actual.toString());
+    tryRelease(actual);
 
     assertFalse(shuffle.isActive(), "closed"); // known-issue
+    drainChannel(decoder);
   }
 
   @Test
@@ -210,11 +215,13 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
     }
 
     DefaultHttpResponse actual = decoder.readInbound();
+    drainChannel(decoder);
     assertFalse(actual.headers().get(CONTENT_LENGTH).isEmpty());
     actual.headers().set(CONTENT_LENGTH, 0);
 
     assertEquals(getExpectedHttpResponse(HttpResponseStatus.INTERNAL_SERVER_ERROR).toString(),
         actual.toString());
+    tryRelease(actual);
 
     assertFalse(shuffle.isActive(), "closed");
   }
@@ -237,15 +244,36 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
     }
 
     DefaultHttpResponse actual = decoder.readInbound();
+    drainChannel(decoder);
     assertFalse(actual.headers().get(CONTENT_LENGTH).isEmpty());
     actual.headers().set(CONTENT_LENGTH, 0);
 
     assertEquals(getExpectedHttpResponse(HttpResponseStatus.INTERNAL_SERVER_ERROR).toString(),
         actual.toString());
+    tryRelease(actual);
 
     assertFalse(shuffle.isActive(), "closed");
   }
 
+  private void drainChannel(EmbeddedChannel ch) {
+    Object o;
+    while((o = ch.readInbound())!=null) {
+      tryRelease(o);
+    }
+    while((o = ch.readOutbound())!=null) {
+      tryRelease(o);
+    }
+  }
+
+  private void tryRelease(Object obj) {
+    if (obj instanceof ReferenceCounted) {
+      ReferenceCounted bb = (ReferenceCounted) obj;
+      if (bb.refCnt() > 0) {
+        bb.release(bb.refCnt());
+      }
+    }
+  }
+
   private DefaultHttpResponse getExpectedHttpResponse(HttpResponseStatus status) {
     DefaultHttpResponse response = new DefaultHttpResponse(HTTP_1_1, status);
     response.headers().set(CONTENT_TYPE, "text/plain; charset=UTF-8");
@@ -365,8 +393,8 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
       assertFalse(shuffle.isActive(), "no keep-alive");
     }
 
-    private void testKeepAlive(java.util.Queue<Object> messages,
-                               EmbeddedChannel shuffle) throws IOException {
+    private void testKeepAlive(java.util.Queue<Object> messages, EmbeddedChannel shuffle)
+        throws IOException, InterruptedException, ExecutionException {
       final FullHttpRequest req1 = createRequest(
           getUri(TEST_JOB_ID, 0, Collections.singletonList(TEST_ATTEMPT_1), true));
       shuffle.writeInbound(req1);
@@ -375,6 +403,7 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
           getAttemptData(new Attempt(TEST_ATTEMPT_1, TEST_DATA_A))
       );
       assertTrue(shuffle.isActive(), "keep-alive");
+      drainChannel(shuffle);
       messages.clear();
 
       final FullHttpRequest req2 = createRequest(
@@ -385,6 +414,7 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
           getAttemptData(new Attempt(TEST_ATTEMPT_2, TEST_DATA_B))
       );
       assertTrue(shuffle.isActive(), "keep-alive");
+      drainChannel(shuffle);
       messages.clear();
 
       final FullHttpRequest req3 = createRequest(
@@ -395,6 +425,7 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
           getAttemptData(new Attempt(TEST_ATTEMPT_3, TEST_DATA_C))
       );
       assertFalse(shuffle.isActive(), "no keep-alive");
+      drainChannel(shuffle);
     }
 
     private ArrayList<ByteBuf> getAllAttemptsForReduce0() throws IOException {
@@ -431,11 +462,13 @@ public class TestShuffleChannelHandler extends TestShuffleHandlerBase {
         decodeChannel.writeInbound(actualBytes);
         Object obj = decodeChannel.readInbound();
         LOG.info("Decoded object: {}", obj);
+        drainChannel(decodeChannel);
 
         if (i == 0) {
           DefaultHttpResponse resp = (DefaultHttpResponse) obj;
           assertEquals(response.toString(), resp.toString());
         }
+        tryRelease(obj);
         if (i > 0 && i <= content.size()) {
           assertEquals(ByteBufUtil.prettyHexDump(content.get(i - 1)),
               actualHexdump, "data should match");

+ 7 - 0
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandlerBase.java

@@ -84,6 +84,13 @@ public class TestShuffleHandlerBase {
 
   @AfterEach
   public void teardown() {
+    //Trigger GC so that we get the leak warnings early
+    System.gc();
+    try {
+      // Wait for logger to flush
+      Thread.sleep(1000);
+    } catch (InterruptedException e) {
+    }
     System.setOut(standardOut);
     System.out.print(outputStreamCaptor);
     // For this to work ch.qos.logback.classic is needed for some reason