瀏覽代碼

YARN-11785. Race condition in QueueMetrics due to non-thread-safe HashMap causes MetricsException. (#7459) Contributed by Tao Yang.

* YARN-11785. Race condition in QueueMetrics due to non-thread-safe HashMap causes MetricsException.

Signed-off-by: Shilun Fan <slfan1989@apache.org>
Tao Yang 2 月之前
父節點
當前提交
6b561d5467

+ 2 - 1
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-resourcemanager/src/main/java/org/apache/hadoop/yarn/server/resourcemanager/scheduler/QueueMetrics.java

@@ -24,6 +24,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceAudience.Private;
@@ -253,7 +254,7 @@ public class QueueMetrics implements MetricsSource {
    * Simple metrics cache to help prevent re-registrations.
    */
   private static final Map<String, QueueMetrics> QUEUE_METRICS =
-      new HashMap<String, QueueMetrics>();
+      new ConcurrentHashMap<>();
 
   /**
    * Returns the metrics cache to help prevent re-registrations.

+ 61 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-resourcemanager/src/test/java/org/apache/hadoop/yarn/server/resourcemanager/scheduler/TestQueueMetrics.java

@@ -37,6 +37,9 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
 import static org.apache.hadoop.test.MetricsAsserts.assertCounter;
 import static org.apache.hadoop.test.MetricsAsserts.getMetrics;
 import static org.apache.hadoop.yarn.server.resourcemanager.scheduler.AppMetricsChecker.AppMetricsKey.APPS_COMPLETED;
@@ -61,6 +64,7 @@ import static org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceMe
 import static org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceMetricsChecker.ResourceMetricsKey.RESERVED_CONTAINERS;
 import static org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceMetricsChecker.ResourceMetricsKey.RESERVED_MB;
 import static org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceMetricsChecker.ResourceMetricsKey.RESERVED_V_CORES;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -794,6 +798,63 @@ public class TestQueueMetrics {
         .checkAgainst(queueSource, true);
   }
 
+  @Test
+  public void testQueueMetricsRaceCondition() throws InterruptedException {
+    final CountDownLatch latch = new CountDownLatch(2);
+    final int numIterations = 100000;
+    final AtomicInteger exceptionCount = new AtomicInteger(0);
+    final AtomicInteger getCount = new AtomicInteger(0);
+
+    // init a queue metrics for testing
+    String queueName = "test";
+    QueueMetrics metrics =
+        QueueMetrics.forQueue(ms, queueName, null, false, conf);
+    QueueMetrics.getQueueMetrics().put(queueName, metrics);
+
+    /*
+     * simulate the concurrent calls for QueueMetrics#getQueueMetrics
+     */
+    // thread A will keep querying the same queue metrics for a specified number of iterations
+    Thread threadA = new Thread(() -> {
+      try {
+        for (int i = 0; i < numIterations; i++) {
+          QueueMetrics qm = QueueMetrics.getQueueMetrics().get(queueName);
+          if (qm != null) {
+            getCount.incrementAndGet();
+          }
+        }
+      } catch (Exception e) {
+        System.out.println("Exception: " + e.getMessage());
+        exceptionCount.incrementAndGet();
+      } finally {
+        latch.countDown();
+      }
+    });
+    // thread B will keep adding new queue metrics for a specified number of iterations
+    Thread threadB = new Thread(() -> {
+      try {
+        for (int i = 0; i < numIterations; i++) {
+          QueueMetrics.getQueueMetrics().put("q" + i, metrics);
+        }
+      } catch (Exception e) {
+        exceptionCount.incrementAndGet();
+      } finally {
+        latch.countDown();
+      }
+    });
+
+    // start threads and wait for them to finish
+    threadA.start();
+    threadB.start();
+    latch.await();
+
+    // check if all get operations are successful to
+    // make sure there is no race condition
+    assertEquals(numIterations, getCount.get());
+    // check if there is any exception
+    assertEquals(0, exceptionCount.get());
+  }
+
   private static void checkAggregatedNodeTypes(MetricsSource source,
       long nodeLocal, long rackLocal, long offSwitch) {
     MetricsRecordBuilder rb = getMetrics(source);