|
@@ -31,8 +31,11 @@ import java.io.IOException;
|
|
|
import java.lang.reflect.InvocationHandler;
|
|
|
import java.lang.reflect.Proxy;
|
|
|
import java.net.InetSocketAddress;
|
|
|
+import java.util.concurrent.atomic.AtomicInteger;
|
|
|
|
|
|
-import static org.junit.Assert.*;
|
|
|
+import static org.junit.Assert.assertEquals;
|
|
|
+import static org.junit.Assert.assertNotEquals;
|
|
|
+import static org.junit.Assert.assertTrue;
|
|
|
import static org.mockito.Mockito.any;
|
|
|
import static org.mockito.Mockito.eq;
|
|
|
import static org.mockito.Mockito.times;
|
|
@@ -51,6 +54,8 @@ public class TestRMFailoverProxyProvider {
|
|
|
private static final int RM2_PORT = 8031;
|
|
|
private static final int RM3_PORT = 8033;
|
|
|
|
|
|
+ private static final int NUM_ITERATIONS = 50;
|
|
|
+
|
|
|
private Configuration conf;
|
|
|
|
|
|
private class TestProxy extends Proxy implements Closeable {
|
|
@@ -303,5 +308,73 @@ public class TestRMFailoverProxyProvider {
|
|
|
.getProxy(any(YarnConfiguration.class), any(Class.class),
|
|
|
eq(mockAdd3));
|
|
|
}
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testRandomSelectRouter() throws Exception {
|
|
|
+
|
|
|
+ // We design a test case like this:
|
|
|
+ // We have three routers (router1, router2, and router3),
|
|
|
+ // we enable Federation mode and random selection mode.
|
|
|
+ // After iterating 50 times, since the selection is random,
|
|
|
+ // each router should be selected more than 0 times,
|
|
|
+ // and the sum of the number of times each router is selected should be equal to 50.
|
|
|
+
|
|
|
+ final AtomicInteger router1Count = new AtomicInteger(0);
|
|
|
+ final AtomicInteger router2Count = new AtomicInteger(0);
|
|
|
+ final AtomicInteger router3Count = new AtomicInteger(0);
|
|
|
+
|
|
|
+ conf.setBoolean(YarnConfiguration.RM_HA_ENABLED, true);
|
|
|
+ conf.setBoolean(YarnConfiguration.FEDERATION_ENABLED, true);
|
|
|
+ conf.setBoolean(YarnConfiguration.FEDERATION_YARN_CLIENT_FAILOVER_RANDOM_ORDER, true);
|
|
|
+ conf.set(YarnConfiguration.RM_HA_IDS, "router0,router1,router2");
|
|
|
+
|
|
|
+ // Create two proxies and mock a RMProxy
|
|
|
+ Proxy mockRouterProxy = new TestProxy((proxy, method, args) -> null);
|
|
|
+
|
|
|
+ Class protocol = ApplicationClientProtocol.class;
|
|
|
+ RMProxy<Proxy> mockRMProxy = mock(RMProxy.class);
|
|
|
+ ConfiguredRMFailoverProxyProvider<Proxy> fpp = new ConfiguredRMFailoverProxyProvider<>();
|
|
|
+
|
|
|
+ // generate two address with different ports.
|
|
|
+ // Default port of yarn RM
|
|
|
+ InetSocketAddress mockRouterAdd = new InetSocketAddress(RM1_PORT);
|
|
|
+
|
|
|
+ // Mock RMProxy methods
|
|
|
+ when(mockRMProxy.getRMAddress(any(YarnConfiguration.class),
|
|
|
+ any(Class.class))).thenReturn(mockRouterAdd);
|
|
|
+ when(mockRMProxy.getProxy(any(YarnConfiguration.class),
|
|
|
+ any(Class.class), eq(mockRouterAdd))).thenReturn(mockRouterProxy);
|
|
|
+
|
|
|
+ // Initialize failover proxy provider and get proxy from it.
|
|
|
+ for (int i = 0; i < NUM_ITERATIONS; i++) {
|
|
|
+ fpp.init(conf, mockRMProxy, protocol);
|
|
|
+ FailoverProxyProvider.ProxyInfo<Proxy> proxy = fpp.getProxy();
|
|
|
+ if ("router0".equals(proxy.proxyInfo)) {
|
|
|
+ router1Count.incrementAndGet();
|
|
|
+ }
|
|
|
+ if ("router1".equals(proxy.proxyInfo)) {
|
|
|
+ router2Count.incrementAndGet();
|
|
|
+ }
|
|
|
+ if ("router2".equals(proxy.proxyInfo)) {
|
|
|
+ router3Count.incrementAndGet();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // router1Count、router2Count、router3Count are
|
|
|
+ // less than NUM_ITERATIONS
|
|
|
+ assertTrue(router1Count.get() < NUM_ITERATIONS);
|
|
|
+ assertTrue(router2Count.get() < NUM_ITERATIONS);
|
|
|
+ assertTrue(router3Count.get() < NUM_ITERATIONS);
|
|
|
+
|
|
|
+ // router1Count、router2Count、router3Count are
|
|
|
+ // more than NUM_ITERATIONS
|
|
|
+ assertTrue(router1Count.get() > 0);
|
|
|
+ assertTrue(router2Count.get() > 0);
|
|
|
+ assertTrue(router3Count.get() > 0);
|
|
|
+
|
|
|
+ // totals(router1Count+router2Count+router3Count ) should be equal NUM_ITERATIONS
|
|
|
+ int totalCount = router1Count.get() + router2Count.get() + router3Count.get();
|
|
|
+ assertEquals(NUM_ITERATIONS, totalCount);
|
|
|
+ }
|
|
|
}
|
|
|
|