|
@@ -18,6 +18,7 @@
|
|
|
|
|
|
package org.apache.hadoop.yarn.client.api.async.impl;
|
|
|
|
|
|
+import com.google.common.base.Supplier;
|
|
|
import static org.mockito.Matchers.anyFloat;
|
|
|
import static org.mockito.Matchers.anyInt;
|
|
|
import static org.mockito.Matchers.anyString;
|
|
@@ -180,7 +181,7 @@ public class TestAMRMClientAsync {
|
|
|
AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
|
|
|
when(client.allocate(anyFloat())).thenThrow(ex);
|
|
|
|
|
|
- AMRMClientAsync<ContainerRequest> asyncClient =
|
|
|
+ AMRMClientAsync<ContainerRequest> asyncClient =
|
|
|
AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
|
|
|
asyncClient.init(conf);
|
|
|
asyncClient.start();
|
|
@@ -228,6 +229,41 @@ public class TestAMRMClientAsync {
|
|
|
asyncClient.stop();
|
|
|
}
|
|
|
|
|
|
+ @Test (timeout = 10000)
|
|
|
+ public void testAMRMClientAsyncShutDownWithWaitFor() throws Exception {
|
|
|
+ Configuration conf = new Configuration();
|
|
|
+ final TestCallbackHandler callbackHandler = new TestCallbackHandler();
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
|
|
|
+
|
|
|
+ final AllocateResponse shutDownResponse = createAllocateResponse(
|
|
|
+ new ArrayList<ContainerStatus>(), new ArrayList<Container>(), null);
|
|
|
+ shutDownResponse.setAMCommand(AMCommand.AM_SHUTDOWN);
|
|
|
+ when(client.allocate(anyFloat())).thenReturn(shutDownResponse);
|
|
|
+
|
|
|
+ AMRMClientAsync<ContainerRequest> asyncClient =
|
|
|
+ AMRMClientAsync.createAMRMClientAsync(client, 10, callbackHandler);
|
|
|
+ asyncClient.init(conf);
|
|
|
+ asyncClient.start();
|
|
|
+
|
|
|
+ Supplier<Boolean> checker = new Supplier<Boolean>() {
|
|
|
+ @Override
|
|
|
+ public Boolean get() {
|
|
|
+ return callbackHandler.reboot;
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
|
|
|
+ asyncClient.waitFor(checker);
|
|
|
+
|
|
|
+ asyncClient.stop();
|
|
|
+ // stopping should have joined all threads and completed all callbacks
|
|
|
+ Assert.assertTrue(callbackHandler.callbackCount == 0);
|
|
|
+
|
|
|
+ verify(client, times(1)).allocate(anyFloat());
|
|
|
+ asyncClient.stop();
|
|
|
+ }
|
|
|
+
|
|
|
@Test (timeout = 5000)
|
|
|
public void testCallAMRMClientAsyncStopFromCallbackHandler()
|
|
|
throws YarnException, IOException, InterruptedException {
|
|
@@ -262,6 +298,40 @@ public class TestAMRMClientAsync {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @Test (timeout = 5000)
|
|
|
+ public void testCallAMRMClientAsyncStopFromCallbackHandlerWithWaitFor()
|
|
|
+ throws YarnException, IOException, InterruptedException {
|
|
|
+ Configuration conf = new Configuration();
|
|
|
+ final TestCallbackHandler2 callbackHandler = new TestCallbackHandler2();
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
|
|
|
+
|
|
|
+ List<ContainerStatus> completed = Arrays.asList(
|
|
|
+ ContainerStatus.newInstance(newContainerId(0, 0, 0, 0),
|
|
|
+ ContainerState.COMPLETE, "", 0));
|
|
|
+ final AllocateResponse response = createAllocateResponse(completed,
|
|
|
+ new ArrayList<Container>(), null);
|
|
|
+
|
|
|
+ when(client.allocate(anyFloat())).thenReturn(response);
|
|
|
+
|
|
|
+ AMRMClientAsync<ContainerRequest> asyncClient =
|
|
|
+ AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
|
|
|
+ callbackHandler.asynClient = asyncClient;
|
|
|
+ asyncClient.init(conf);
|
|
|
+ asyncClient.start();
|
|
|
+
|
|
|
+ Supplier<Boolean> checker = new Supplier<Boolean>() {
|
|
|
+ @Override
|
|
|
+ public Boolean get() {
|
|
|
+ return callbackHandler.notify;
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
|
|
|
+ asyncClient.waitFor(checker);
|
|
|
+ Assert.assertTrue(checker.get());
|
|
|
+ }
|
|
|
+
|
|
|
void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws
|
|
|
InterruptedException, YarnException, IOException {
|
|
|
Configuration conf = new Configuration();
|
|
@@ -342,7 +412,7 @@ public class TestAMRMClientAsync {
|
|
|
private volatile List<ContainerStatus> completedContainers;
|
|
|
private volatile List<Container> allocatedContainers;
|
|
|
Exception savedException = null;
|
|
|
- boolean reboot = false;
|
|
|
+ volatile boolean reboot = false;
|
|
|
Object notifier = new Object();
|
|
|
|
|
|
int callbackCount = 0;
|
|
@@ -432,7 +502,7 @@ public class TestAMRMClientAsync {
|
|
|
@SuppressWarnings("rawtypes")
|
|
|
AMRMClientAsync asynClient;
|
|
|
boolean stop = true;
|
|
|
- boolean notify = false;
|
|
|
+ volatile boolean notify = false;
|
|
|
boolean throwOutException = false;
|
|
|
|
|
|
@Override
|