|
@@ -23,7 +23,10 @@ import static org.mockito.Matchers.anyInt;
|
|
|
import static org.mockito.Matchers.anyString;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
|
|
|
+import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.List;
|
|
@@ -218,6 +221,65 @@ public class TestAMRMClientAsync {
|
|
|
Assert.assertTrue(callbackHandler.callbackCount == 0);
|
|
|
}
|
|
|
|
|
|
+ @Test (timeout = 10000)
|
|
|
+ public void testAMRMClientAsyncShutDown() throws Exception {
|
|
|
+ Configuration conf = new Configuration();
|
|
|
+ TestCallbackHandler callbackHandler = new TestCallbackHandler();
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
|
|
|
+
|
|
|
+ final AllocateResponse shutDownResponse = createAllocateResponse(
|
|
|
+ new ArrayList<ContainerStatus>(), new ArrayList<Container>(), null);
|
|
|
+ shutDownResponse.setAMCommand(AMCommand.AM_SHUTDOWN);
|
|
|
+ when(client.allocate(anyFloat())).thenReturn(shutDownResponse);
|
|
|
+
|
|
|
+ AMRMClientAsync<ContainerRequest> asyncClient =
|
|
|
+ AMRMClientAsync.createAMRMClientAsync(client, 10, callbackHandler);
|
|
|
+ asyncClient.init(conf);
|
|
|
+ asyncClient.start();
|
|
|
+
|
|
|
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
|
|
|
+
|
|
|
+ Thread.sleep(50);
|
|
|
+
|
|
|
+ verify(client, times(1)).allocate(anyFloat());
|
|
|
+ asyncClient.stop();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test (timeout = 5000)
|
|
|
+ public void testCallAMRMClientAsyncStopFromCallbackHandler()
|
|
|
+ throws YarnException, IOException, InterruptedException {
|
|
|
+ Configuration conf = new Configuration();
|
|
|
+ TestCallbackHandler2 callbackHandler = new TestCallbackHandler2();
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
|
|
|
+
|
|
|
+ List<ContainerStatus> completed = Arrays.asList(
|
|
|
+ ContainerStatus.newInstance(newContainerId(0, 0, 0, 0),
|
|
|
+ ContainerState.COMPLETE, "", 0));
|
|
|
+ final AllocateResponse response = createAllocateResponse(completed,
|
|
|
+ new ArrayList<Container>(), null);
|
|
|
+
|
|
|
+ when(client.allocate(anyFloat())).thenReturn(response);
|
|
|
+
|
|
|
+ AMRMClientAsync<ContainerRequest> asyncClient =
|
|
|
+ AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
|
|
|
+ callbackHandler.registerAsyncClient(asyncClient);
|
|
|
+ asyncClient.init(conf);
|
|
|
+ asyncClient.start();
|
|
|
+
|
|
|
+ synchronized (callbackHandler.notifier) {
|
|
|
+ asyncClient.registerApplicationMaster("localhost", 1234, null);
|
|
|
+ while(callbackHandler.stop == false) {
|
|
|
+ try {
|
|
|
+ callbackHandler.notifier.wait();
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private AllocateResponse createAllocateResponse(
|
|
|
List<ContainerStatus> completed, List<Container> allocated,
|
|
|
List<NMToken> nmTokens) {
|
|
@@ -323,4 +385,41 @@ public class TestAMRMClientAsync {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private class TestCallbackHandler2 implements AMRMClientAsync.CallbackHandler {
|
|
|
+ Object notifier = new Object();
|
|
|
+ @SuppressWarnings("rawtypes")
|
|
|
+ AMRMClientAsync asynClient;
|
|
|
+ boolean stop = false;
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onContainersCompleted(List<ContainerStatus> statuses) {}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onContainersAllocated(List<Container> containers) {}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onShutdownRequest() {}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onNodesUpdated(List<NodeReport> updatedNodes) {}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public float getProgress() {
|
|
|
+ asynClient.stop();
|
|
|
+ stop = true;
|
|
|
+ synchronized (notifier) {
|
|
|
+ notifier.notifyAll();
|
|
|
+ }
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onError(Exception e) {}
|
|
|
+
|
|
|
+ public void registerAsyncClient(
|
|
|
+ AMRMClientAsync<ContainerRequest> asyncClient) {
|
|
|
+ this.asynClient = asyncClient;
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|