Browse Source

HDFS-14088. RequestHedgingProxyProvider can throw NullPointerException when failover due to no lock on currentUsedProxy. Contributed by Yuxuan Wang.

Giovanni Matteo Fumarola 6 years ago
parent
commit
f858f18554

+ 114 - 95
hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/server/namenode/ha/RequestHedgingProxyProvider.java

@@ -58,6 +58,8 @@ public class RequestHedgingProxyProvider<T> extends
   class RequestHedgingInvocationHandler implements InvocationHandler {
   class RequestHedgingInvocationHandler implements InvocationHandler {
 
 
     final Map<String, ProxyInfo<T>> targetProxies;
     final Map<String, ProxyInfo<T>> targetProxies;
+    // Proxy of the active nn
+    private volatile ProxyInfo<T> currentUsedProxy = null;
 
 
     public RequestHedgingInvocationHandler(
     public RequestHedgingInvocationHandler(
             Map<String, ProxyInfo<T>> targetProxies) {
             Map<String, ProxyInfo<T>> targetProxies) {
@@ -79,104 +81,118 @@ public class RequestHedgingProxyProvider<T> extends
     public Object
     public Object
     invoke(Object proxy, final Method method, final Object[] args)
     invoke(Object proxy, final Method method, final Object[] args)
             throws Throwable {
             throws Throwable {
-      if (currentUsedProxy != null) {
-        try {
-          Object retVal = method.invoke(currentUsedProxy.proxy, args);
-          LOG.debug("Invocation successful on [{}]",
-              currentUsedProxy.proxyInfo);
-          return retVal;
-        } catch (InvocationTargetException ex) {
-          Exception unwrappedException = unwrapInvocationTargetException(ex);
-          logProxyException(unwrappedException, currentUsedProxy.proxyInfo);
-          LOG.trace("Unsuccessful invocation on [{}]",
-              currentUsedProxy.proxyInfo);
-          throw unwrappedException;
-        }
-      }
-      Map<Future<Object>, ProxyInfo<T>> proxyMap = new HashMap<>();
-      int numAttempts = 0;
+      // Double-checked locking is needed for thread safety because
+      // currentUsedProxy is lazily initialized.
+      if (currentUsedProxy == null) {
+        synchronized (this) {
+          if (currentUsedProxy == null) {
+            Map<Future<Object>, ProxyInfo<T>> proxyMap = new HashMap<>();
+            int numAttempts = 0;
 
 
-      ExecutorService executor = null;
-      CompletionService<Object> completionService;
-      try {
-        // Optimization : if only 2 proxies are configured and one had failed
-        // over, then we dont need to create a threadpool etc.
-        targetProxies.remove(toIgnore);
-        if (targetProxies.size() == 0) {
-          LOG.trace("No valid proxies left");
-          throw new RemoteException(IOException.class.getName(),
-              "No valid proxies left. All NameNode proxies have failed over.");
-        }
-        if (targetProxies.size() == 1) {
-          ProxyInfo<T> proxyInfo = targetProxies.values().iterator().next();
-          try {
-            currentUsedProxy = proxyInfo;
-            Object retVal = method.invoke(proxyInfo.proxy, args);
-            LOG.debug("Invocation successful on [{}]",
-                currentUsedProxy.proxyInfo);
-            return retVal;
-          } catch (InvocationTargetException ex) {
-            Exception unwrappedException = unwrapInvocationTargetException(ex);
-            logProxyException(unwrappedException, currentUsedProxy.proxyInfo);
-            LOG.trace("Unsuccessful invocation on [{}]",
-                currentUsedProxy.proxyInfo);
-            throw unwrappedException;
-          }
-        }
-        executor = Executors.newFixedThreadPool(proxies.size());
-        completionService = new ExecutorCompletionService<>(executor);
-        for (final Map.Entry<String, ProxyInfo<T>> pEntry :
-                targetProxies.entrySet()) {
-          Callable<Object> c = new Callable<Object>() {
-            @Override
-            public Object call() throws Exception {
-              LOG.trace("Invoking method {} on proxy {}", method,
-                  pEntry.getValue().proxyInfo);
-              return method.invoke(pEntry.getValue().proxy, args);
-            }
-          };
-          proxyMap.put(completionService.submit(c), pEntry.getValue());
-          numAttempts++;
-        }
+            ExecutorService executor = null;
+            CompletionService<Object> completionService;
+            try {
+              // Optimization: if only 2 proxies are configured and one has
+              // failed over, then we don't need to create a thread pool
+              // etc.
+              targetProxies.remove(toIgnore);
+              if (targetProxies.size() == 0) {
+                LOG.trace("No valid proxies left");
+                throw new RemoteException(IOException.class.getName(),
+                    "No valid proxies left. "
+                        + "All NameNode proxies have failed over.");
+              }
+              if (targetProxies.size() == 1) {
+                ProxyInfo<T> proxyInfo =
+                    targetProxies.values().iterator().next();
+                try {
+                  currentUsedProxy = proxyInfo;
+                  Object retVal = method.invoke(proxyInfo.proxy, args);
+                  LOG.debug("Invocation successful on [{}]",
+                      currentUsedProxy.proxyInfo);
+                  return retVal;
+                } catch (InvocationTargetException ex) {
+                  Exception unwrappedException =
+                      unwrapInvocationTargetException(ex);
+                  logProxyException(unwrappedException,
+                      currentUsedProxy.proxyInfo);
+                  LOG.trace("Unsuccessful invocation on [{}]",
+                      currentUsedProxy.proxyInfo);
+                  throw unwrappedException;
+                }
+              }
+              executor = Executors.newFixedThreadPool(proxies.size());
+              completionService = new ExecutorCompletionService<>(executor);
+              for (final Map.Entry<String, ProxyInfo<T>> pEntry : targetProxies
+                  .entrySet()) {
+                Callable<Object> c = new Callable<Object>() {
+                  @Override
+                  public Object call() throws Exception {
+                    LOG.trace("Invoking method {} on proxy {}", method,
+                        pEntry.getValue().proxyInfo);
+                    return method.invoke(pEntry.getValue().proxy, args);
+                  }
+                };
+                proxyMap.put(completionService.submit(c), pEntry.getValue());
+                numAttempts++;
+              }
 
 
-        Map<String, Exception> badResults = new HashMap<>();
-        while (numAttempts > 0) {
-          Future<Object> callResultFuture = completionService.take();
-          Object retVal;
-          try {
-            currentUsedProxy = proxyMap.get(callResultFuture);
-            retVal = callResultFuture.get();
-            LOG.debug("Invocation successful on [{}]",
-                currentUsedProxy.proxyInfo);
-            return retVal;
-          } catch (ExecutionException ex) {
-            Exception unwrappedException = unwrapExecutionException(ex);
-            ProxyInfo<T> tProxyInfo = proxyMap.get(callResultFuture);
-            logProxyException(unwrappedException, tProxyInfo.proxyInfo);
-            badResults.put(tProxyInfo.proxyInfo, unwrappedException);
-            LOG.trace("Unsuccessful invocation on [{}]", tProxyInfo.proxyInfo);
-            numAttempts--;
-          }
-        }
+              Map<String, Exception> badResults = new HashMap<>();
+              while (numAttempts > 0) {
+                Future<Object> callResultFuture = completionService.take();
+                Object retVal;
+                try {
+                  currentUsedProxy = proxyMap.get(callResultFuture);
+                  retVal = callResultFuture.get();
+                  LOG.debug("Invocation successful on [{}]",
+                      currentUsedProxy.proxyInfo);
+                  return retVal;
+                } catch (ExecutionException ex) {
+                  Exception unwrappedException = unwrapExecutionException(ex);
+                  ProxyInfo<T> tProxyInfo = proxyMap.get(callResultFuture);
+                  logProxyException(unwrappedException, tProxyInfo.proxyInfo);
+                  badResults.put(tProxyInfo.proxyInfo, unwrappedException);
+                  LOG.trace("Unsuccessful invocation on [{}]",
+                      tProxyInfo.proxyInfo);
+                  numAttempts--;
+                }
+              }
 
 
-        // At this point we should have All bad results (Exceptions)
-        // Or should have returned with successful result.
-        if (badResults.size() == 1) {
-          throw badResults.values().iterator().next();
-        } else {
-          throw new MultiException(badResults);
-        }
-      } finally {
-        if (executor != null) {
-          LOG.trace("Shutting down threadpool executor");
-          executor.shutdownNow();
+              // At this point we should have All bad results (Exceptions)
+              // Or should have returned with successful result.
+              if (badResults.size() == 1) {
+                throw badResults.values().iterator().next();
+              } else {
+                throw new MultiException(badResults);
+              }
+            } finally {
+              if (executor != null) {
+                LOG.trace("Shutting down threadpool executor");
+                executor.shutdownNow();
+              }
+            }
+          }
         }
         }
       }
       }
+      // The synchronized block above always returns or throws, so no extra
+      // check is needed to keep the thread that performed the initialization
+      // from falling through to the code below.
+      try {
+        Object retVal = method.invoke(currentUsedProxy.proxy, args);
+        LOG.debug("Invocation successful on [{}]", currentUsedProxy.proxyInfo);
+        return retVal;
+      } catch (InvocationTargetException ex) {
+        Exception unwrappedException = unwrapInvocationTargetException(ex);
+        logProxyException(unwrappedException, currentUsedProxy.proxyInfo);
+        LOG.trace("Unsuccessful invocation on [{}]",
+            currentUsedProxy.proxyInfo);
+        throw unwrappedException;
+      }
     }
     }
   }
   }
 
 
-
-  private volatile ProxyInfo<T> currentUsedProxy = null;
+  /** A proxy wrapping {@link RequestHedgingInvocationHandler}. */
+  private ProxyInfo<T> currentUsedHandler = null;
   private volatile String toIgnore = null;
   private volatile String toIgnore = null;
 
 
   public RequestHedgingProxyProvider(Configuration conf, URI uri,
   public RequestHedgingProxyProvider(Configuration conf, URI uri,
@@ -187,8 +203,8 @@ public class RequestHedgingProxyProvider<T> extends
   @SuppressWarnings("unchecked")
   @SuppressWarnings("unchecked")
   @Override
   @Override
   public synchronized ProxyInfo<T> getProxy() {
   public synchronized ProxyInfo<T> getProxy() {
-    if (currentUsedProxy != null) {
-      return currentUsedProxy;
+    if (currentUsedHandler != null) {
+      return currentUsedHandler;
     }
     }
     Map<String, ProxyInfo<T>> targetProxyInfos = new HashMap<>();
     Map<String, ProxyInfo<T>> targetProxyInfos = new HashMap<>();
     StringBuilder combinedInfo = new StringBuilder("[");
     StringBuilder combinedInfo = new StringBuilder("[");
@@ -203,13 +219,16 @@ public class RequestHedgingProxyProvider<T> extends
             RequestHedgingInvocationHandler.class.getClassLoader(),
             RequestHedgingInvocationHandler.class.getClassLoader(),
             new Class<?>[]{xface},
             new Class<?>[]{xface},
             new RequestHedgingInvocationHandler(targetProxyInfos));
             new RequestHedgingInvocationHandler(targetProxyInfos));
-    return new ProxyInfo<T>(wrappedProxy, combinedInfo.toString());
+    currentUsedHandler =
+        new ProxyInfo<T>(wrappedProxy, combinedInfo.toString());
+    return currentUsedHandler;
   }
   }
 
 
   @Override
   @Override
   public synchronized void performFailover(T currentProxy) {
   public synchronized void performFailover(T currentProxy) {
-    toIgnore = this.currentUsedProxy.proxyInfo;
-    this.currentUsedProxy = null;
+    toIgnore = ((RequestHedgingInvocationHandler) Proxy.getInvocationHandler(
+        currentUsedHandler.proxy)).currentUsedProxy.proxyInfo;
+    this.currentUsedHandler = null;
   }
   }
 
 
   /**
   /**

+ 60 - 0
hadoop-hdfs-project/hadoop-hdfs-client/src/test/java/org/apache/hadoop/hdfs/server/namenode/ha/TestRequestHedgingProxyProvider.java

@@ -37,6 +37,7 @@ import org.apache.hadoop.ipc.RemoteException;
 import org.apache.hadoop.ipc.StandbyException;
 import org.apache.hadoop.ipc.StandbyException;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.test.GenericTestUtils;
 import org.apache.hadoop.test.GenericTestUtils;
+import org.apache.hadoop.test.LambdaTestUtils;
 import org.apache.hadoop.util.Time;
 import org.apache.hadoop.util.Time;
 import org.apache.log4j.Level;
 import org.apache.log4j.Level;
 import org.junit.Assert;
 import org.junit.Assert;
@@ -634,6 +635,65 @@ public class TestRequestHedgingProxyProvider {
     Mockito.verify(standby).getStats();
     Mockito.verify(standby).getStats();
   }
   }
 
 
+  /**
+   * HDFS-14088: we first make a successful RPC call, so
+   * RequestHedgingInvocationHandler#currentUsedProxy is assigned to the
+   * delayMock. Then: <br/>
+   * 1. We start a thread which sleeps for 1 sec and then calls
+   * RequestHedgingProxyProvider#performFailover(). <br/>
+   * 2. We make the RPC call again; this call sleeps for 2 sec and then throws
+   * an exception for the test.<br/>
+   * 3. RequestHedgingInvocationHandler#invoke() catches the exception and
+   * logs RequestHedgingInvocationHandler#currentUsedProxy. Before the patch,
+   * this step threw a NullPointerException.
+   * @throws Exception
+   */
+  @Test
+  public void testHedgingMultiThreads() throws Exception {
+    final AtomicInteger counter = new AtomicInteger(0);
+    final ClientProtocol delayMock = Mockito.mock(ClientProtocol.class);
+    Mockito.when(delayMock.getStats()).thenAnswer(new Answer<long[]>() {
+      @Override
+      public long[] answer(InvocationOnMock invocation) throws Throwable {
+        int flag = counter.incrementAndGet();
+        Thread.sleep(2000);
+        if (flag == 1) {
+          return new long[]{1};
+        } else {
+          throw new IOException("Exception for test.");
+        }
+      }
+    });
+    final ClientProtocol badMock = Mockito.mock(ClientProtocol.class);
+    Mockito.when(badMock.getStats()).thenThrow(new IOException("Bad mock !!"));
+    final RequestHedgingProxyProvider<ClientProtocol> provider =
+        new RequestHedgingProxyProvider<>(conf, nnUri, ClientProtocol.class,
+            createFactory(delayMock, badMock));
+    final ClientProtocol delayProxy = provider.getProxy().proxy;
+    long[] stats = delayProxy.getStats();
+    Assert.assertTrue(stats.length == 1);
+    Assert.assertEquals(1, stats[0]);
+    Assert.assertEquals(1, counter.get());
+
+    Thread t = new Thread() {
+      @Override
+      public void run() {
+        try {
+          // Trigger the failover while delayProxy.getStats() is still in
+          // flight, i.e. after it is called but before it throws.
+          Thread.sleep(1000);
+          provider.performFailover(delayProxy);
+        } catch (Exception e) {
+          e.printStackTrace();
+        }
+      }
+    };
+    t.start();
+    LambdaTestUtils.intercept(IOException.class, "Exception for test.",
+        delayProxy::getStats);
+    t.join();
+  }
+
   private HAProxyFactory<ClientProtocol> createFactory(
   private HAProxyFactory<ClientProtocol> createFactory(
       ClientProtocol... protos) {
       ClientProtocol... protos) {
     final Iterator<ClientProtocol> iterator =
     final Iterator<ClientProtocol> iterator =