فهرست منبع

HADOOP-18426. Use weighted calculation for MutableStat mean/variance to fix accuracy. (#4844). Contributed by Erik Krogen.

Co-authored-by: Shuyan Zhang <zqingchai@gmail.com>
Signed-off-by: He Xiaoqiao <hexiaoqiao@apache.org>
Erik Krogen 2 سال پیش
والد
کامیت
c664f953c9

+ 26 - 32
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/metrics2/util/SampleStat.java

@@ -27,33 +27,29 @@ import org.apache.hadoop.classification.InterfaceAudience;
 public class SampleStat {
   private final MinMax minmax = new MinMax();
   private long numSamples = 0;
-  private double a0, a1, s0, s1, total;
+  private double mean, s;
 
   /**
    * Construct a new running sample stat
    */
   public SampleStat() {
-    a0 = s0 = 0.0;
-    total = 0.0;
+    mean = 0.0;
+    s = 0.0;
   }
 
   public void reset() {
     numSamples = 0;
-    a0 = s0 = 0.0;
-    total = 0.0;
+    mean = 0.0;
+    s = 0.0;
     minmax.reset();
   }
 
   // We want to reuse the object, sometimes.
-  void reset(long numSamples, double a0, double a1, double s0, double s1,
-      double total, MinMax minmax) {
-    this.numSamples = numSamples;
-    this.a0 = a0;
-    this.a1 = a1;
-    this.s0 = s0;
-    this.s1 = s1;
-    this.total = total;
-    this.minmax.reset(minmax);
+  void reset(long numSamples1, double mean1, double s1, MinMax minmax1) {
+    numSamples = numSamples1;
+    mean = mean1;
+    s = s1;
+    minmax.reset(minmax1);
   }
 
   /**
@@ -61,7 +57,7 @@ public class SampleStat {
    * @param other the destination to hold our values
    */
   public void copyTo(SampleStat other) {
-    other.reset(numSamples, a0, a1, s0, s1, total, minmax);
+    other.reset(numSamples, mean, s, minmax);
   }
 
   /**
@@ -78,24 +74,22 @@ public class SampleStat {
    * Add some sample and a partial sum to the running stat.
    * Note, min/max is not evaluated using this method.
    * @param nSamples  number of samples
-   * @param x the partial sum
+   * @param xTotal the partial sum
    * @return  self
    */
-  public SampleStat add(long nSamples, double x) {
+  public SampleStat add(long nSamples, double xTotal) {
     numSamples += nSamples;
-    total += x;
 
-    if (numSamples == 1) {
-      a0 = a1 = x;
-      s0 = 0.0;
-    }
-    else {
-      // The Welford method for numerical stability
-      a1 = a0 + (x - a0) / numSamples;
-      s1 = s0 + (x - a0) * (x - a1);
-      a0 = a1;
-      s0 = s1;
-    }
+    // use the weighted incremental version of Welford's algorithm to get
+    // numerical stability while treating the samples as being weighted
+    // by nSamples
+    // see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+
+    double x = xTotal / nSamples;
+    double meanOld = mean;
+
+    mean += ((double) nSamples / numSamples) * (x - meanOld);
+    s += nSamples * (x - meanOld) * (x - mean);
     return this;
   }
 
@@ -110,21 +104,21 @@ public class SampleStat {
    * @return the total of all samples added
    */
   public double total() {
-    return total;
+    return mean * numSamples;
   }
 
   /**
    * @return  the arithmetic mean of the samples
    */
   public double mean() {
-    return numSamples > 0 ? (total / numSamples) : 0.0;
+    return numSamples > 0 ? mean : 0.0;
   }
 
   /**
    * @return  the variance of the samples
    */
   public double variance() {
-    return numSamples > 1 ? s1 / (numSamples - 1) : 0.0;
+    return numSamples > 1 ? s / (numSamples - 1) : 0.0;
   }
 
   /**

+ 48 - 8
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/metrics2/lib/TestMutableMetrics.java

@@ -29,6 +29,8 @@ import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.junit.Assert.*;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Random;
@@ -36,6 +38,7 @@ import java.util.concurrent.CountDownLatch;
 
 import org.apache.hadoop.metrics2.MetricsRecordBuilder;
 import org.apache.hadoop.metrics2.util.Quantile;
+import org.apache.hadoop.thirdparty.com.google.common.math.Stats;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -47,7 +50,7 @@ public class TestMutableMetrics {
 
   private static final Logger LOG =
       LoggerFactory.getLogger(TestMutableMetrics.class);
-  private final double EPSILON = 1e-42;
+  private static final double EPSILON = 1e-42;
 
   /**
    * Test the snapshot method
@@ -306,19 +309,56 @@ public class TestMutableMetrics {
 
   /**
    * Tests that when using {@link MutableStat#add(long, long)}, even with a high
-   * sample count, the mean does not lose accuracy.
+   * sample count, the mean does not lose accuracy. This also validates that
+   * the std dev is correct, assuming samples of equal value.
    */
-  @Test public void testMutableStatWithBulkAdd() {
+  @Test
+  public void testMutableStatWithBulkAdd() {
+    List<Long> samples = new ArrayList<>();
+    for (int i = 0; i < 1000; i++) {
+      samples.add(1000L);
+    }
+    for (int i = 0; i < 1000; i++) {
+      samples.add(2000L);
+    }
+    Stats stats = Stats.of(samples);
+
+    for (int bulkSize : new int[] {1, 10, 100, 1000}) {
+      MetricsRecordBuilder rb = mockMetricsRecordBuilder();
+      MetricsRegistry registry = new MetricsRegistry("test");
+      MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
+
+      for (int i = 0; i < samples.size(); i += bulkSize) {
+        stat.add(bulkSize, samples
+            .subList(i, i + bulkSize)
+            .stream()
+            .mapToLong(Long::longValue)
+            .sum()
+        );
+      }
+      registry.snapshot(rb, false);
+
+      assertCounter("TestNumOps", 2000L, rb);
+      assertGauge("TestAvgVal", stats.mean(), rb);
+      assertGauge("TestStdevVal", stats.sampleStandardDeviation(), rb);
+    }
+  }
+
+  @Test
+  public void testLargeMutableStatAdd() {
     MetricsRecordBuilder rb = mockMetricsRecordBuilder();
     MetricsRegistry registry = new MetricsRegistry("test");
-    MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", false);
+    MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
 
-    stat.add(1000, 1000);
-    stat.add(1000, 2000);
+    long sample = 1000000000000009L;
+    for (int i = 0; i < 100; i++) {
+      stat.add(1, sample);
+    }
     registry.snapshot(rb, false);
 
-    assertCounter("TestNumOps", 2000L, rb);
-    assertGauge("TestAvgVal", 1.5, rb);
+    assertCounter("TestNumOps", 100L, rb);
+    assertGauge("TestAvgVal", (double) sample, rb);
+    assertGauge("TestStdevVal", 0.0, rb);
   }
 
   /**