Jelajahi Sumber

YARN-10465. Support getNodeToLabels, getLabelsToNodes, getClusterNodeLabels API's for Federation (#4317)

slfan1989 3 tahun lalu
induk
melakukan
8dd3ef1f08

+ 99 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/RouterMetrics.java

@@ -57,6 +57,12 @@ public final class RouterMetrics {
   private MutableGaugeInt numGetClusterMetricsFailedRetrieved;
   @Metric("# of getClusterNodes failed to be retrieved")
   private MutableGaugeInt numGetClusterNodesFailedRetrieved;
+  @Metric("# of getNodeToLabels failed to be retrieved")
+  private MutableGaugeInt numGetNodeToLabelsFailedRetrieved;
+  @Metric("# of getNodeToLabels failed to be retrieved")
+  private MutableGaugeInt numGetLabelsToNodesFailedRetrieved;
+  @Metric("# of getClusterNodeLabels failed to be retrieved")
+  private MutableGaugeInt numGetClusterNodeLabelsFailedRetrieved;
 
   // Aggregate metrics are shared, and don't have to be looked up per call
   @Metric("Total number of successful Submitted apps and latency(ms)")
@@ -78,6 +84,12 @@ public final class RouterMetrics {
   private MutableRate totalSucceededGetClusterMetricsRetrieved;
   @Metric("Total number of successful Retrieved getClusterNodes and latency(ms)")
   private MutableRate totalSucceededGetClusterNodesRetrieved;
+  @Metric("Total number of successful Retrieved getNodeToLabels and latency(ms)")
+  private MutableRate totalSucceededGetNodeToLabelsRetrieved;
+  @Metric("Total number of successful Retrieved getNodeToLabels and latency(ms)")
+  private MutableRate totalSucceededGetLabelsToNodesRetrieved;
+  @Metric("Total number of successful Retrieved getClusterNodeLabels and latency(ms)")
+  private MutableRate totalSucceededGetClusterNodeLabelsRetrieved;
 
   /**
    * Provide quantile counters for all latencies.
@@ -90,6 +102,9 @@ public final class RouterMetrics {
   private MutableQuantiles getApplicationAttemptReportLatency;
   private MutableQuantiles getClusterMetricsLatency;
   private MutableQuantiles getClusterNodesLatency;
+  private MutableQuantiles getNodeToLabelsLatency;
+  private MutableQuantiles getLabelToNodesLatency;
+  private MutableQuantiles getClusterNodeLabelsLatency;
 
   private static volatile RouterMetrics INSTANCE = null;
   private static MetricsRegistry registry;
@@ -120,6 +135,18 @@ public final class RouterMetrics {
     getClusterNodesLatency =
         registry.newQuantiles("getClusterNodesLatency",
             "latency of get cluster nodes", "ops", "latency", 10);
+
+    getNodeToLabelsLatency =
+        registry.newQuantiles("getNodeToLabelsLatency",
+            "latency of get node labels", "ops", "latency", 10);
+
+    getLabelToNodesLatency =
+        registry.newQuantiles("getLabelToNodesLatency",
+            "latency of get label nodes", "ops", "latency", 10);
+
+    getClusterNodeLabelsLatency =
+        registry.newQuantiles("getClusterNodeLabelsLatency",
+            "latency of get cluster node labels", "ops", "latency", 10);
   }
 
   public static RouterMetrics getMetrics() {
@@ -181,6 +208,21 @@ public final class RouterMetrics {
     return totalSucceededGetClusterNodesRetrieved.lastStat().numSamples();
   }
 
+  @VisibleForTesting
+  public long getNumSucceededGetNodeToLabelsRetrieved(){
+    return totalSucceededGetNodeToLabelsRetrieved.lastStat().numSamples();
+  }
+
+  @VisibleForTesting
+  public long getNumSucceededGetLabelsToNodesRetrieved(){
+    return totalSucceededGetLabelsToNodesRetrieved.lastStat().numSamples();
+  }
+
+  @VisibleForTesting
+  public long getNumSucceededGetClusterNodeLabelsRetrieved(){
+    return totalSucceededGetClusterNodeLabelsRetrieved.lastStat().numSamples();
+  }
+
   @VisibleForTesting
   public double getLatencySucceededAppsCreated() {
     return totalSucceededAppsCreated.lastStat().mean();
@@ -221,6 +263,21 @@ public final class RouterMetrics {
     return totalSucceededGetClusterNodesRetrieved.lastStat().mean();
   }
 
+  @VisibleForTesting
+  public double getLatencySucceededGetNodeToLabelsRetrieved() {
+    return totalSucceededGetNodeToLabelsRetrieved.lastStat().mean();
+  }
+
+  @VisibleForTesting
+  public double getLatencySucceededGetLabelsToNodesRetrieved() {
+    return totalSucceededGetLabelsToNodesRetrieved.lastStat().mean();
+  }
+
+  @VisibleForTesting
+  public double getLatencySucceededGetClusterNodeLabelsRetrieved() {
+    return totalSucceededGetClusterNodeLabelsRetrieved.lastStat().mean();
+  }
+
   @VisibleForTesting
   public int getAppsFailedCreated() {
     return numAppsFailedCreated.value();
@@ -261,6 +318,21 @@ public final class RouterMetrics {
     return numGetClusterNodesFailedRetrieved.value();
   }
 
+  @VisibleForTesting
+  public int getNodeToLabelsFailedRetrieved() {
+    return numGetNodeToLabelsFailedRetrieved.value();
+  }
+
+  @VisibleForTesting
+  public int getLabelsToNodesFailedRetrieved() {
+    return numGetLabelsToNodesFailedRetrieved.value();
+  }
+
+  @VisibleForTesting
+  public int getGetClusterNodeLabelsFailedRetrieved() {
+    return numGetClusterNodeLabelsFailedRetrieved.value();
+  }
+
   public void succeededAppsCreated(long duration) {
     totalSucceededAppsCreated.add(duration);
     getNewApplicationLatency.add(duration);
@@ -301,6 +373,21 @@ public final class RouterMetrics {
     getClusterNodesLatency.add(duration);
   }
 
+  public void succeededGetNodeToLabelsRetrieved(long duration) {
+    totalSucceededGetNodeToLabelsRetrieved.add(duration);
+    getNodeToLabelsLatency.add(duration);
+  }
+
+  public void succeededGetLabelsToNodesRetrieved(long duration) {
+    totalSucceededGetLabelsToNodesRetrieved.add(duration);
+    getLabelToNodesLatency.add(duration);
+  }
+
+  public void succeededGetClusterNodeLabelsRetrieved(long duration) {
+    totalSucceededGetClusterNodeLabelsRetrieved.add(duration);
+    getClusterNodeLabelsLatency.add(duration);
+  }
+
   public void incrAppsFailedCreated() {
     numAppsFailedCreated.incr();
   }
@@ -332,4 +419,16 @@ public final class RouterMetrics {
   public void incrClusterNodesFailedRetrieved() {
     numGetClusterNodesFailedRetrieved.incr();
   }
+
+  public void incrNodeToLabelsFailedRetrieved() {
+    numGetNodeToLabelsFailedRetrieved.incr();
+  }
+
+  public void incrLabelsToNodesFailedRetrieved() {
+    numGetLabelsToNodesFailedRetrieved.incr();
+  }
+
+  public void incrClusterNodeLabelsFailedRetrieved() {
+    numGetClusterNodeLabelsFailedRetrieved.incr();
+  }
 }

+ 80 - 3
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/clientrm/FederationClientInterceptor.java

@@ -21,6 +21,7 @@ package org.apache.hadoop.yarn.server.router.clientrm;
 import org.apache.hadoop.thirdparty.com.google.common.collect.Maps;
 import org.apache.hadoop.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder;
 import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -38,6 +39,7 @@ import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 import org.apache.commons.lang3.NotImplementedException;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.CommonConfigurationKeys;
@@ -867,22 +869,97 @@ public class FederationClientInterceptor
     throw new NotImplementedException("Code is not implemented");
   }
 
+  private <R> Collection<R> invokeAppClientProtocolMethod(
+      Boolean filterInactiveSubClusters, ClientMethod request, Class<R> clazz)
+          throws YarnException, RuntimeException {
+    Map<SubClusterId, SubClusterInfo> subClusters =
+        federationFacade.getSubClusters(filterInactiveSubClusters);
+    return subClusters.keySet().stream().map(subClusterId -> {
+      try {
+        ApplicationClientProtocol protocol = getClientRMProxyForSubCluster(subClusterId);
+        Method method = ApplicationClientProtocol.class.
+            getMethod(request.getMethodName(), request.getTypes());
+        return clazz.cast(method.invoke(protocol, request.getParams()));
+      } catch (YarnException | NoSuchMethodException |
+               IllegalAccessException | InvocationTargetException ex) {
+        throw new RuntimeException(ex);
+      }
+    }).collect(Collectors.toList());
+  }
+
   @Override
   public GetNodesToLabelsResponse getNodeToLabels(
       GetNodesToLabelsRequest request) throws YarnException, IOException {
-    throw new NotImplementedException("Code is not implemented");
+    if (request == null) {
+      routerMetrics.incrNodeToLabelsFailedRetrieved();
+      RouterServerUtil.logAndThrowException("Missing getNodesToLabels request.", null);
+    }
+    long startTime = clock.getTime();
+    ClientMethod remoteMethod = new ClientMethod("getNodeToLabels",
+         new Class[] {GetNodesToLabelsRequest.class}, new Object[] {request});
+    Collection<GetNodesToLabelsResponse> clusterNodes;
+    try {
+      clusterNodes = invokeAppClientProtocolMethod(true, remoteMethod,
+          GetNodesToLabelsResponse.class);
+    } catch (Exception ex) {
+      routerMetrics.incrNodeToLabelsFailedRetrieved();
+      LOG.error("Unable to get label node due to exception.", ex);
+      throw ex;
+    }
+    long stopTime = clock.getTime();
+    routerMetrics.succeededGetNodeToLabelsRetrieved(stopTime - startTime);
+    // Merge the NodesToLabelsResponse
+    return RouterYarnClientUtils.mergeNodesToLabelsResponse(clusterNodes);
   }
 
   @Override
   public GetLabelsToNodesResponse getLabelsToNodes(
       GetLabelsToNodesRequest request) throws YarnException, IOException {
-    throw new NotImplementedException("Code is not implemented");
+    if (request == null) {
+      routerMetrics.incrLabelsToNodesFailedRetrieved();
+      RouterServerUtil.logAndThrowException("Missing getLabelsToNodes request.", null);
+    }
+    long startTime = clock.getTime();
+    ClientMethod remoteMethod = new ClientMethod("getLabelsToNodes",
+         new Class[] {GetLabelsToNodesRequest.class}, new Object[] {request});
+    Collection<GetLabelsToNodesResponse> labelNodes;
+    try {
+      labelNodes = invokeAppClientProtocolMethod(true, remoteMethod,
+          GetLabelsToNodesResponse.class);
+    } catch (Exception ex) {
+      routerMetrics.incrLabelsToNodesFailedRetrieved();
+      LOG.error("Unable to get label node due to exception.", ex);
+      throw ex;
+    }
+    long stopTime = clock.getTime();
+    routerMetrics.succeededGetLabelsToNodesRetrieved(stopTime - startTime);
+    // Merge the LabelsToNodesResponse
+    return RouterYarnClientUtils.mergeLabelsToNodes(labelNodes);
   }
 
   @Override
   public GetClusterNodeLabelsResponse getClusterNodeLabels(
       GetClusterNodeLabelsRequest request) throws YarnException, IOException {
-    throw new NotImplementedException("Code is not implemented");
+    if (request == null) {
+      routerMetrics.incrClusterNodeLabelsFailedRetrieved();
+      RouterServerUtil.logAndThrowException("Missing getClusterNodeLabels request.", null);
+    }
+    long startTime = clock.getTime();
+    ClientMethod remoteMethod = new ClientMethod("getClusterNodeLabels",
+         new Class[] {GetClusterNodeLabelsRequest.class}, new Object[] {request});
+    Collection<GetClusterNodeLabelsResponse> nodeLabels;
+    try {
+      nodeLabels = invokeAppClientProtocolMethod(true, remoteMethod,
+          GetClusterNodeLabelsResponse.class);
+    } catch (Exception ex) {
+      routerMetrics.incrClusterNodeLabelsFailedRetrieved();
+      LOG.error("Unable to get cluster nodeLabels due to exception.", ex);
+      throw ex;
+    }
+    long stopTime = clock.getTime();
+    routerMetrics.succeededGetClusterNodeLabelsRetrieved(stopTime - startTime);
+    // Merge the ClusterNodeLabelsResponse
+    return RouterYarnClientUtils.mergeClusterNodeLabelsResponse(nodeLabels);
   }
 
   /**

+ 78 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/clientrm/RouterYarnClientUtils.java

@@ -22,15 +22,22 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.List;
 import java.util.ArrayList;
+import java.util.Set;
+import java.util.HashSet;
 
 import org.apache.hadoop.yarn.api.protocolrecords.GetApplicationsResponse;
 import org.apache.hadoop.yarn.api.protocolrecords.GetClusterMetricsResponse;
 import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodesResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetNodesToLabelsResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetLabelsToNodesResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodeLabelsResponse;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationReport;
 import org.apache.hadoop.yarn.api.records.ApplicationResourceUsageReport;
 import org.apache.hadoop.yarn.api.records.YarnClusterMetrics;
 import org.apache.hadoop.yarn.api.records.NodeReport;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.api.records.NodeLabel;
 import org.apache.hadoop.yarn.server.uam.UnmanagedApplicationManager;
 import org.apache.hadoop.yarn.util.Records;
 import org.apache.hadoop.yarn.util.resource.Resources;
@@ -218,4 +225,75 @@ public final class RouterYarnClientUtils {
     clusterNodesResponse.setNodeReports(nodeReports);
     return clusterNodesResponse;
   }
+
+  /**
+   * Merges a list of GetNodesToLabelsResponse.
+   *
+   * @param responses a list of GetNodesToLabelsResponse to merge.
+   * @return the merged GetNodesToLabelsResponse.
+   */
+  public static GetNodesToLabelsResponse mergeNodesToLabelsResponse(
+      Collection<GetNodesToLabelsResponse> responses) {
+    GetNodesToLabelsResponse nodesToLabelsResponse = Records.newRecord(
+         GetNodesToLabelsResponse.class);
+    Map<NodeId, Set<String>> nodesToLabelMap = new HashMap<>();
+    for (GetNodesToLabelsResponse response : responses) {
+      if (response != null && response.getNodeToLabels() != null) {
+        nodesToLabelMap.putAll(response.getNodeToLabels());
+      }
+    }
+    nodesToLabelsResponse.setNodeToLabels(nodesToLabelMap);
+    return nodesToLabelsResponse;
+  }
+
+  /**
+   * Merges a list of GetLabelsToNodesResponse.
+   *
+   * @param responses a list of GetLabelsToNodesResponse to merge.
+   * @return the merged GetLabelsToNodesResponse.
+   */
+  public static GetLabelsToNodesResponse mergeLabelsToNodes(
+      Collection<GetLabelsToNodesResponse> responses){
+    GetLabelsToNodesResponse labelsToNodesResponse = Records.newRecord(
+        GetLabelsToNodesResponse.class);
+    Map<String, Set<NodeId>> labelsToNodesMap = new HashMap<>();
+    for (GetLabelsToNodesResponse response : responses) {
+      if (response != null && response.getLabelsToNodes() != null) {
+        Map<String, Set<NodeId>> clusterLabelsToNodesMap = response.getLabelsToNodes();
+        for (Map.Entry<String, Set<NodeId>> entry : clusterLabelsToNodesMap.entrySet()) {
+          String label = entry.getKey();
+          Set<NodeId> clusterNodes = entry.getValue();
+          if (labelsToNodesMap.containsKey(label)) {
+            Set<NodeId> allNodes = labelsToNodesMap.get(label);
+            allNodes.addAll(clusterNodes);
+          } else {
+            labelsToNodesMap.put(label, clusterNodes);
+          }
+        }
+      }
+    }
+    labelsToNodesResponse.setLabelsToNodes(labelsToNodesMap);
+    return labelsToNodesResponse;
+  }
+
+  /**
+   * Merges a list of GetClusterNodeLabelsResponse.
+   *
+   * @param responses a list of GetClusterNodeLabelsResponse to merge.
+   * @return the merged GetClusterNodeLabelsResponse.
+   */
+  public static GetClusterNodeLabelsResponse mergeClusterNodeLabelsResponse(
+      Collection<GetClusterNodeLabelsResponse> responses) {
+    GetClusterNodeLabelsResponse nodeLabelsResponse = Records.newRecord(
+        GetClusterNodeLabelsResponse.class);
+    Set<NodeLabel> nodeLabelsList = new HashSet<>();
+    for (GetClusterNodeLabelsResponse response : responses) {
+      if (response != null && response.getNodeLabelList() != null) {
+        nodeLabelsList.addAll(response.getNodeLabelList());
+      }
+    }
+    nodeLabelsResponse.setNodeLabelList(new ArrayList<>(nodeLabelsList));
+    return nodeLabelsResponse;
+  }
 }
+

+ 92 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/TestRouterMetrics.java

@@ -353,6 +353,21 @@ public class TestRouterMetrics {
       LOG.info("Mocked: failed getClusterNodes call");
       metrics.incrClusterNodesFailedRetrieved();
     }
+
+    public void getNodeToLabels() {
+      LOG.info("Mocked: failed getNodeToLabels call");
+      metrics.incrNodeToLabelsFailedRetrieved();
+    }
+
+    public void getLabelToNodes() {
+      LOG.info("Mocked: failed getLabelToNodes call");
+      metrics.incrLabelsToNodesFailedRetrieved();
+    }
+
+    public void getClusterNodeLabels() {
+      LOG.info("Mocked: failed getClusterNodeLabels call");
+      metrics.incrClusterNodeLabelsFailedRetrieved();
+    }
   }
 
   // Records successes for all calls
@@ -404,6 +419,21 @@ public class TestRouterMetrics {
       LOG.info("Mocked: successful getClusterNodes call with duration {}", duration);
       metrics.succeededGetClusterNodesRetrieved(duration);
     }
+
+    public void getNodeToLabels(long duration) {
+      LOG.info("Mocked: successful getNodeToLabels call with duration {}", duration);
+      metrics.succeededGetNodeToLabelsRetrieved(duration);
+    }
+
+    public void getLabelToNodes(long duration) {
+      LOG.info("Mocked: successful getLabelToNodes call with duration {}", duration);
+      metrics.succeededGetLabelsToNodesRetrieved(duration);
+    }
+
+    public void getClusterNodeLabels(long duration) {
+      LOG.info("Mocked: successful getClusterNodeLabels call with duration {}", duration);
+      metrics.succeededGetClusterNodeLabelsRetrieved(duration);
+    }
   }
 
   @Test
@@ -425,4 +455,66 @@ public class TestRouterMetrics {
     badSubCluster.getClusterNodes();
     Assert.assertEquals(totalBadBefore + 1, metrics.getClusterNodesFailedRetrieved());
   }
+
+  @Test
+  public void testSucceededGetNodeToLabels() {
+    long totalGoodBefore = metrics.getNumSucceededGetNodeToLabelsRetrieved();
+    goodSubCluster.getNodeToLabels(150);
+    Assert.assertEquals(totalGoodBefore + 1, metrics.getNumSucceededGetNodeToLabelsRetrieved());
+    Assert.assertEquals(150, metrics.getLatencySucceededGetNodeToLabelsRetrieved(),
+        ASSERT_DOUBLE_DELTA);
+    goodSubCluster.getNodeToLabels(300);
+    Assert.assertEquals(totalGoodBefore + 2, metrics.getNumSucceededGetNodeToLabelsRetrieved());
+    Assert.assertEquals(225, metrics.getLatencySucceededGetNodeToLabelsRetrieved(),
+        ASSERT_DOUBLE_DELTA);
+  }
+
+  @Test
+  public void testGetNodeToLabelsFailed() {
+    long totalBadBefore = metrics.getNodeToLabelsFailedRetrieved();
+    badSubCluster.getNodeToLabels();
+    Assert.assertEquals(totalBadBefore + 1, metrics.getNodeToLabelsFailedRetrieved());
+  }
+
+  @Test
+  public void testSucceededLabelsToNodes() {
+    long totalGoodBefore = metrics.getNumSucceededGetLabelsToNodesRetrieved();
+    goodSubCluster.getLabelToNodes(150);
+    Assert.assertEquals(totalGoodBefore + 1, metrics.getNumSucceededGetLabelsToNodesRetrieved());
+    Assert.assertEquals(150, metrics.getLatencySucceededGetLabelsToNodesRetrieved(),
+        ASSERT_DOUBLE_DELTA);
+    goodSubCluster.getLabelToNodes(300);
+    Assert.assertEquals(totalGoodBefore + 2, metrics.getNumSucceededGetLabelsToNodesRetrieved());
+    Assert.assertEquals(225, metrics.getLatencySucceededGetLabelsToNodesRetrieved(),
+        ASSERT_DOUBLE_DELTA);
+  }
+
+  @Test
+  public void testGetLabelsToNodesFailed() {
+    long totalBadBefore = metrics.getLabelsToNodesFailedRetrieved();
+    badSubCluster.getLabelToNodes();
+    Assert.assertEquals(totalBadBefore + 1, metrics.getLabelsToNodesFailedRetrieved());
+  }
+
+  @Test
+  public void testSucceededClusterNodeLabels() {
+    long totalGoodBefore = metrics.getNumSucceededGetClusterNodeLabelsRetrieved();
+    goodSubCluster.getClusterNodeLabels(150);
+    Assert.assertEquals(totalGoodBefore + 1,
+        metrics.getNumSucceededGetClusterNodeLabelsRetrieved());
+    Assert.assertEquals(150,
+        metrics.getLatencySucceededGetClusterNodeLabelsRetrieved(), ASSERT_DOUBLE_DELTA);
+    goodSubCluster.getClusterNodeLabels(300);
+    Assert.assertEquals(totalGoodBefore + 2,
+        metrics.getNumSucceededGetClusterNodeLabelsRetrieved());
+    Assert.assertEquals(225, metrics.getLatencySucceededGetClusterNodeLabelsRetrieved(),
+        ASSERT_DOUBLE_DELTA);
+  }
+
+  @Test
+  public void testClusterNodeLabelsFailed() {
+    long totalBadBefore = metrics.getGetClusterNodeLabelsFailedRetrieved();
+    badSubCluster.getClusterNodeLabels();
+    Assert.assertEquals(totalBadBefore + 1, metrics.getGetClusterNodeLabelsFailedRetrieved());
+  }
 }

+ 42 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/clientrm/TestFederationClientInterceptor.java

@@ -46,6 +46,12 @@ import org.apache.hadoop.yarn.api.protocolrecords.SubmitApplicationRequest;
 import org.apache.hadoop.yarn.api.protocolrecords.SubmitApplicationResponse;
 import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodesResponse;
 import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodesRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.GetNodesToLabelsResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetNodesToLabelsRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.GetLabelsToNodesResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetLabelsToNodesRequest;
+import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodeLabelsResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodeLabelsRequest;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
@@ -655,4 +661,40 @@ public class TestFederationClientInterceptor extends BaseRouterClientRMTest {
         interceptor.getClusterNodes(GetClusterNodesRequest.newInstance());
     Assert.assertEquals(subClusters.size(), response.getNodeReports().size());
   }
+
+  @Test
+  public void testGetNodeToLabelsRequest() throws Exception {
+    LOG.info("Test FederationClientInterceptor :  Get Node To Labels request");
+    // null request
+    LambdaTestUtils.intercept(YarnException.class, "Missing getNodesToLabels request.",
+        () -> interceptor.getNodeToLabels(null));
+    // normal request.
+    GetNodesToLabelsResponse response =
+        interceptor.getNodeToLabels(GetNodesToLabelsRequest.newInstance());
+    Assert.assertEquals(0, response.getNodeToLabels().size());
+  }
+
+  @Test
+  public void testGetLabelsToNodesRequest() throws Exception {
+    LOG.info("Test FederationClientInterceptor :  Get Labels To Node request");
+    // null request
+    LambdaTestUtils.intercept(YarnException.class, "Missing getLabelsToNodes request.",
+        () -> interceptor.getLabelsToNodes(null));
+    // normal request.
+    GetLabelsToNodesResponse response =
+        interceptor.getLabelsToNodes(GetLabelsToNodesRequest.newInstance());
+    Assert.assertEquals(0, response.getLabelsToNodes().size());
+  }
+
+  @Test
+  public void testClusterNodeLabelsRequest() throws Exception {
+    LOG.info("Test FederationClientInterceptor :  Get Cluster NodeLabels request");
+    // null request
+    LambdaTestUtils.intercept(YarnException.class, "Missing getClusterNodeLabels request.",
+        () -> interceptor.getClusterNodeLabels(null));
+    // normal request.
+    GetClusterNodeLabelsResponse response =
+        interceptor.getClusterNodeLabels(GetClusterNodeLabelsRequest.newInstance());
+    Assert.assertEquals(0, response.getNodeLabelList().size());
+  }
 }

+ 154 - 0
hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/clientrm/TestRouterYarnClientUtils.java

@@ -20,9 +20,18 @@ package org.apache.hadoop.yarn.server.router.clientrm;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Set;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.HashSet;
 
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet;
 import org.apache.hadoop.yarn.api.protocolrecords.GetApplicationsResponse;
 import org.apache.hadoop.yarn.api.protocolrecords.GetClusterMetricsResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetNodesToLabelsResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetClusterNodeLabelsResponse;
+import org.apache.hadoop.yarn.api.protocolrecords.GetLabelsToNodesResponse;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationReport;
@@ -31,6 +40,9 @@ import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.api.records.YarnApplicationState;
 import org.apache.hadoop.yarn.api.records.YarnClusterMetrics;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.api.records.NodeLabel;
+import org.apache.hadoop.yarn.util.Records;
 import org.apache.hadoop.yarn.server.uam.UnmanagedApplicationManager;
 import org.junit.Assert;
 import org.junit.Test;
@@ -213,4 +225,146 @@ public class TestRouterYarnClientUtils {
 
     return GetApplicationsResponse.newInstance(applications);
   }
+
+  @Test
+  public void testMergeNodesToLabelsResponse() {
+    NodeId node1 = NodeId.fromString("SubCluster1Node1:1111");
+    NodeId node2 = NodeId.fromString("SubCluster1Node2:2222");
+    NodeId node3 = NodeId.fromString("SubCluster2Node1:1111");
+
+    Map<NodeId, Set<String>> nodeLabelsMapSC1 = new HashMap<>();
+    nodeLabelsMapSC1.put(node1, ImmutableSet.of("node1"));
+    nodeLabelsMapSC1.put(node2, ImmutableSet.of("node2"));
+    nodeLabelsMapSC1.put(node3, ImmutableSet.of("node3"));
+
+    // normal response
+    GetNodesToLabelsResponse response1 = Records.newRecord(
+        GetNodesToLabelsResponse.class);
+    response1.setNodeToLabels(nodeLabelsMapSC1);
+
+    // empty response
+    Map<NodeId, Set<String>> nodeLabelsMapSC2 = new HashMap<>();
+    GetNodesToLabelsResponse response2 = Records.newRecord(
+        GetNodesToLabelsResponse.class);
+    response2.setNodeToLabels(nodeLabelsMapSC2);
+
+    // null response
+    GetNodesToLabelsResponse response3 = null;
+
+    Map<NodeId, Set<String>> expectedResponse = new HashMap<>();
+    expectedResponse.put(node1, ImmutableSet.of("node1"));
+    expectedResponse.put(node2, ImmutableSet.of("node2"));
+    expectedResponse.put(node3, ImmutableSet.of("node3"));
+
+    List<GetNodesToLabelsResponse> responses = new ArrayList<>();
+    responses.add(response1);
+    responses.add(response2);
+    responses.add(response3);
+
+    GetNodesToLabelsResponse response = RouterYarnClientUtils.
+        mergeNodesToLabelsResponse(responses);
+    Assert.assertEquals(expectedResponse, response.getNodeToLabels());
+  }
+
+  @Test
+  public void testMergeClusterNodeLabelsResponse() {
+    NodeLabel nodeLabel1 = NodeLabel.newInstance("nodeLabel1");
+    NodeLabel nodeLabel2 = NodeLabel.newInstance("nodeLabel2");
+    NodeLabel nodeLabel3 = NodeLabel.newInstance("nodeLabel3");
+
+    // normal response
+    List<NodeLabel> nodeLabelListSC1 = new ArrayList<>();
+    nodeLabelListSC1.add(nodeLabel1);
+    nodeLabelListSC1.add(nodeLabel2);
+    nodeLabelListSC1.add(nodeLabel3);
+
+    GetClusterNodeLabelsResponse response1 = Records.newRecord(
+        GetClusterNodeLabelsResponse.class);
+    response1.setNodeLabelList(nodeLabelListSC1);
+
+    // empty response
+    List<NodeLabel> nodeLabelListSC2 = new ArrayList<>();
+
+    GetClusterNodeLabelsResponse response2 = Records.newRecord(
+        GetClusterNodeLabelsResponse.class);
+    response2.setNodeLabelList(nodeLabelListSC2);
+
+    // null response
+    GetClusterNodeLabelsResponse response3 = null;
+
+    List<GetClusterNodeLabelsResponse> responses = new ArrayList<>();
+    responses.add(response1);
+    responses.add(response2);
+    responses.add(response3);
+
+    List<NodeLabel> expectedResponse = new ArrayList<>();
+    expectedResponse.add(nodeLabel1);
+    expectedResponse.add(nodeLabel2);
+    expectedResponse.add(nodeLabel3);
+
+    GetClusterNodeLabelsResponse response = RouterYarnClientUtils.
+        mergeClusterNodeLabelsResponse(responses);
+    Assert.assertTrue(CollectionUtils.isEqualCollection(expectedResponse,
+        response.getNodeLabelList()));
+  }
+
+  @Test
+  public void testMergeLabelsToNodes(){
+    NodeId node1 = NodeId.fromString("SubCluster1Node1:1111");
+    NodeId node2 = NodeId.fromString("SubCluster1Node2:2222");
+    NodeId node3 = NodeId.fromString("SubCluster2node1:1111");
+    NodeId node4 = NodeId.fromString("SubCluster2node2:2222");
+
+    Map<String, Set<NodeId>> labelsToNodesSC1 = new HashMap<>();
+
+    Set<NodeId> nodeIdSet1 = new HashSet<>();
+    nodeIdSet1.add(node1);
+    nodeIdSet1.add(node2);
+    labelsToNodesSC1.put("Label1", nodeIdSet1);
+
+    // normal response
+    GetLabelsToNodesResponse response1 = Records.newRecord(
+        GetLabelsToNodesResponse.class);
+    response1.setLabelsToNodes(labelsToNodesSC1);
+    Map<String, Set<NodeId>> labelsToNodesSC2 = new HashMap<>();
+    Set<NodeId> nodeIdSet2 = new HashSet<>();
+    nodeIdSet2.add(node3);
+    Set<NodeId> nodeIdSet3 = new HashSet<>();
+    nodeIdSet3.add(node4);
+    labelsToNodesSC2.put("Label1", nodeIdSet2);
+    labelsToNodesSC2.put("Label2", nodeIdSet3);
+
+    GetLabelsToNodesResponse response2 = Records.newRecord(
+        GetLabelsToNodesResponse.class);
+    response2.setLabelsToNodes(labelsToNodesSC2);
+
+    // empty response
+    GetLabelsToNodesResponse response3 = Records.newRecord(
+        GetLabelsToNodesResponse.class);
+
+    // null response
+    GetLabelsToNodesResponse response4 = null;
+
+    List<GetLabelsToNodesResponse> responses = new ArrayList<>();
+    responses.add(response1);
+    responses.add(response2);
+    responses.add(response3);
+    responses.add(response4);
+
+    Map<String, Set<NodeId>> expectedResponse = new HashMap<>();
+    Set<NodeId> nodeIdMergedSet1 = new HashSet<>();
+    nodeIdMergedSet1.add(node1);
+    nodeIdMergedSet1.add(node2);
+    nodeIdMergedSet1.add(node3);
+
+    Set<NodeId> nodeIdMergedSet2 = new HashSet<>();
+    nodeIdMergedSet2.add(node4);
+    expectedResponse.put("Label1", nodeIdMergedSet1);
+    expectedResponse.put("Label2", nodeIdMergedSet2);
+
+    GetLabelsToNodesResponse response = RouterYarnClientUtils.
+        mergeLabelsToNodes(responses);
+
+    Assert.assertEquals(expectedResponse, response.getLabelsToNodes());
+  }
 }