Browse Source

HADOOP-11604. Prevent ConcurrentModificationException while closing domain sockets during shutdown of DomainSocketWatcher thread. Contributed by Chris Nauroth.

(cherry picked from commit 3c5ff0759c4f4e10c97c6d9036add00edb8be2b5)
(cherry picked from commit 187e081d5a8afe1ddfe5d7b5e7de7a94512aa53e)
cnauroth 10 years ago
parent
commit
342504f790

+ 3 - 0
hadoop-common-project/hadoop-common/CHANGES.txt

@@ -42,6 +42,9 @@ Release 2.6.1 - UNRELEASED
     HADOOP-11295. RPC Server Reader thread can't shutdown if RPCCallQueue is
     full. (Ming Ma via kihwal)
 
+    HADOOP-11604. Prevent ConcurrentModificationException while closing domain
+    sockets during shutdown of DomainSocketWatcher thread. (cnauroth)
+
 Release 2.6.0 - 2014-11-18
 
   INCOMPATIBLE CHANGES

+ 41 - 4
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/net/unix/DomainSocketWatcher.java

@@ -246,6 +246,13 @@ public final class DomainSocketWatcher implements Closeable {
     this.interruptCheckPeriodMs = interruptCheckPeriodMs;
     notificationSockets = DomainSocket.socketpair();
     watcherThread.setDaemon(true);
+    watcherThread.setUncaughtExceptionHandler(
+        new Thread.UncaughtExceptionHandler() {
+          @Override
+          public void uncaughtException(Thread thread, Throwable t) {
+            LOG.error(thread + " terminating on unexpected exception", t);
+          }
+        });
     watcherThread.start();
   }
 
@@ -372,7 +379,17 @@ public final class DomainSocketWatcher implements Closeable {
     }
   }
 
-  private void sendCallback(String caller, TreeMap<Integer, Entry> entries,
+  /**
+   * Send callback and return whether or not the domain socket was closed as a
+   * result of processing.
+   *
+   * @param caller reason for call
+   * @param entries mapping of file descriptor to entry
+   * @param fdSet set of file descriptors
+   * @param fd file descriptor
+   * @return true if the domain socket was closed as a result of processing
+   */
+  private boolean sendCallback(String caller, TreeMap<Integer, Entry> entries,
       FdSet fdSet, int fd) {
     if (LOG.isTraceEnabled()) {
       LOG.trace(this + ": " + caller + " starting sendCallback for fd " + fd);
@@ -401,13 +418,30 @@ public final class DomainSocketWatcher implements Closeable {
             "still in the poll(2) loop.");
       }
       IOUtils.cleanup(LOG, sock);
-      entries.remove(fd);
       fdSet.remove(fd);
+      return true;
     } else {
       if (LOG.isTraceEnabled()) {
         LOG.trace(this + ": " + caller + ": sendCallback not " +
             "closing fd " + fd);
       }
+      return false;
+    }
+  }
+
+  /**
+   * Send callback, and if the domain socket was closed as a result of
+   * processing, then also remove the entry for the file descriptor.
+   *
+   * @param caller reason for call
+   * @param entries mapping of file descriptor to entry
+   * @param fdSet set of file descriptors
+   * @param fd file descriptor
+   */
+  private void sendCallbackAndRemove(String caller,
+      TreeMap<Integer, Entry> entries, FdSet fdSet, int fd) {
+    if (sendCallback(caller, entries, fdSet, fd)) {
+      entries.remove(fd);
     }
   }
 
@@ -427,7 +461,8 @@ public final class DomainSocketWatcher implements Closeable {
           lock.lock();
           try {
             for (int fd : fdSet.getAndClearReadableFds()) {
-              sendCallback("getAndClearReadableFds", entries, fdSet, fd);
+              sendCallbackAndRemove("getAndClearReadableFds", entries, fdSet,
+                  fd);
             }
             if (!(toAdd.isEmpty() && toRemove.isEmpty())) {
               // Handle pending additions (before pending removes).
@@ -448,7 +483,7 @@ public final class DomainSocketWatcher implements Closeable {
               while (true) {
                 Map.Entry<Integer, DomainSocket> entry = toRemove.firstEntry();
                 if (entry == null) break;
-                sendCallback("handlePendingRemovals",
+                sendCallbackAndRemove("handlePendingRemovals",
                     entries, fdSet, entry.getValue().fd);
               }
               processedCond.signalAll();
@@ -482,6 +517,8 @@ public final class DomainSocketWatcher implements Closeable {
         try {
           kick(); // allow the handler for notificationSockets[0] to read a byte
           for (Entry entry : entries.values()) {
+            // We do not remove from entries as we iterate, because that can
+            // cause a ConcurrentModificationException.
             sendCallback("close", entries, fdSet, entry.getDomainSocket().fd);
           }
           entries.clear();

+ 61 - 4
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/unix/TestDomainSocketWatcher.java

@@ -17,6 +17,8 @@
  */
 package org.apache.hadoop.net.unix;
 
+import static org.junit.Assert.assertFalse;
+
 import java.util.ArrayList;
 import java.util.Random;
 import java.util.concurrent.CountDownLatch;
@@ -25,6 +27,7 @@ import java.util.concurrent.locks.ReentrantLock;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.junit.After;
 import org.junit.Assume;
 import org.junit.Before;
 import org.junit.Test;
@@ -34,17 +37,28 @@ import com.google.common.util.concurrent.Uninterruptibles;
 public class TestDomainSocketWatcher {
   static final Log LOG = LogFactory.getLog(TestDomainSocketWatcher.class);
 
+  private Throwable trappedException = null;
+
   @Before
   public void before() {
     Assume.assumeTrue(DomainSocket.getLoadingFailureReason() == null);
   }
 
+  @After
+  public void after() {
+    if (trappedException != null) {
+      throw new IllegalStateException(
+          "DomainSocketWatcher thread terminated with unexpected exception.",
+          trappedException);
+    }
+  }
+
   /**
    * Test that we can create a DomainSocketWatcher and then shut it down.
    */
   @Test(timeout=60000)
   public void testCreateShutdown() throws Exception {
-    DomainSocketWatcher watcher = new DomainSocketWatcher(10000000);
+    DomainSocketWatcher watcher = newDomainSocketWatcher(10000000);
     watcher.close();
   }
 
@@ -53,7 +67,7 @@ public class TestDomainSocketWatcher {
    */
   @Test(timeout=180000)
   public void testDeliverNotifications() throws Exception {
-    DomainSocketWatcher watcher = new DomainSocketWatcher(10000000);
+    DomainSocketWatcher watcher = newDomainSocketWatcher(10000000);
     DomainSocket pair[] = DomainSocket.socketpair();
     final CountDownLatch latch = new CountDownLatch(1);
     watcher.add(pair[1], new DomainSocketWatcher.Handler() {
@@ -73,17 +87,35 @@ public class TestDomainSocketWatcher {
    */
   @Test(timeout=60000)
   public void testInterruption() throws Exception {
-    final DomainSocketWatcher watcher = new DomainSocketWatcher(10);
+    final DomainSocketWatcher watcher = newDomainSocketWatcher(10);
     watcher.watcherThread.interrupt();
     Uninterruptibles.joinUninterruptibly(watcher.watcherThread);
     watcher.close();
   }
+
+  /**
+   * Test that domain sockets are closed when the watcher is closed.
+   */
+  @Test(timeout=300000)
+  public void testCloseSocketOnWatcherClose() throws Exception {
+    final DomainSocketWatcher watcher = newDomainSocketWatcher(10000000);
+    DomainSocket pair[] = DomainSocket.socketpair();
+    watcher.add(pair[1], new DomainSocketWatcher.Handler() {
+      @Override
+      public boolean handle(DomainSocket sock) {
+        return true;
+      }
+    });
+    watcher.close();
+    Uninterruptibles.joinUninterruptibly(watcher.watcherThread);
+    assertFalse(pair[1].isOpen());
+  }
   
   @Test(timeout=300000)
   public void testStress() throws Exception {
     final int SOCKET_NUM = 250;
     final ReentrantLock lock = new ReentrantLock();
-    final DomainSocketWatcher watcher = new DomainSocketWatcher(10000000);
+    final DomainSocketWatcher watcher = newDomainSocketWatcher(10000000);
     final ArrayList<DomainSocket[]> pairs = new ArrayList<DomainSocket[]>();
     final AtomicInteger handled = new AtomicInteger(0);
 
@@ -148,4 +180,29 @@ public class TestDomainSocketWatcher {
     Uninterruptibles.joinUninterruptibly(removerThread);
     watcher.close();
   }
+
+  /**
+   * Creates a new DomainSocketWatcher and tracks its thread for termination due
+   * to an unexpected exception.  At the end of each test, if there was an
+   * unexpected exception, then that exception is thrown to force a failure of
+   * the test.
+   *
+   * @param interruptCheckPeriodMs interrupt check period passed to
+   *     DomainSocketWatcher
+   * @return new DomainSocketWatcher
+   * @throws Exception if there is any failure
+   */
+  private DomainSocketWatcher newDomainSocketWatcher(int interruptCheckPeriodMs)
+      throws Exception {
+    DomainSocketWatcher watcher = new DomainSocketWatcher(
+        interruptCheckPeriodMs);
+    watcher.watcherThread.setUncaughtExceptionHandler(
+        new Thread.UncaughtExceptionHandler() {
+          @Override
+          public void uncaughtException(Thread thread, Throwable t) {
+            trappedException = t;
+          }
+        });
+    return watcher;
+  }
 }