浏览代码

ZOOKEEPER-3731: Disallow HTTP TRACE method on PrometheusMetrics Server (#1682)

doxsch 2 年之前
父节点
当前提交
f46b8fb87f

+ 23 - 0
zookeeper-metrics-providers/zookeeper-prometheus-metrics/src/main/java/org/apache/zookeeper/metrics/prometheus/PrometheusMetricsProvider.java

@@ -52,9 +52,12 @@ import org.apache.zookeeper.metrics.MetricsProviderLifeCycleException;
 import org.apache.zookeeper.metrics.Summary;
 import org.apache.zookeeper.metrics.SummarySet;
 import org.apache.zookeeper.server.RateLogger;
+import org.eclipse.jetty.security.ConstraintMapping;
+import org.eclipse.jetty.security.ConstraintSecurityHandler;
 import org.eclipse.jetty.server.Server;
 import org.eclipse.jetty.servlet.ServletContextHandler;
 import org.eclipse.jetty.servlet.ServletHolder;
+import org.eclipse.jetty.util.security.Constraint;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -134,6 +137,7 @@ public class PrometheusMetricsProvider implements MetricsProvider {
             server = new Server(new InetSocketAddress(host, port));
             ServletContextHandler context = new ServletContextHandler();
             context.setContextPath("/");
+            constrainTraceMethod(context);
             server.setHandler(context);
             context.addServlet(new ServletHolder(servlet), "/metrics");
             server.start();
@@ -235,6 +239,25 @@ public class PrometheusMetricsProvider implements MetricsProvider {
         // not supported on Prometheus
     }
 
+    /**
+     * Add constraint to a given context to disallow TRACE method.
+     * @param ctxHandler the context to modify
+     */
+    private void constrainTraceMethod(ServletContextHandler ctxHandler) {
+        Constraint c = new Constraint();
+        c.setAuthenticate(true);
+
+        ConstraintMapping cmt = new ConstraintMapping();
+        cmt.setConstraint(c);
+        cmt.setMethod("TRACE");
+        cmt.setPathSpec("/*");
+
+        ConstraintSecurityHandler securityHandler = new ConstraintSecurityHandler();
+        securityHandler.setConstraintMappings(new ConstraintMapping[] {cmt});
+
+        ctxHandler.setSecurityHandler(securityHandler);
+    }
+
     private class Context implements MetricsContext {
 
         private final ConcurrentMap<String, PrometheusGaugeWrapper> gauges = new ConcurrentHashMap<>();

+ 24 - 0
zookeeper-metrics-providers/zookeeper-prometheus-metrics/src/test/java/org/apache/zookeeper/metrics/prometheus/PrometheusMetricsProviderTest.java

@@ -31,6 +31,9 @@ import io.prometheus.client.CollectorRegistry;
 import java.io.IOException;
 import java.io.PrintWriter;
 import java.io.StringWriter;
+import java.lang.reflect.Field;
+import java.net.HttpURLConnection;
+import java.net.URL;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -50,7 +53,10 @@ import org.apache.zookeeper.metrics.MetricsContext;
 import org.apache.zookeeper.metrics.Summary;
 import org.apache.zookeeper.metrics.SummarySet;
 import org.apache.zookeeper.server.util.QuotaMetricsUtils;
+import org.eclipse.jetty.server.Server;
+import org.eclipse.jetty.server.ServerConnector;
 import org.hamcrest.CoreMatchers;
+import org.junit.Assert;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -61,6 +67,7 @@ import org.junit.jupiter.api.Test;
  */
 public class PrometheusMetricsProviderTest {
 
+    private static final String URL_FORMAT = "http://localhost:%d/metrics";
     private PrometheusMetricsProvider provider;
 
     @BeforeEach
@@ -395,6 +402,23 @@ public class PrometheusMetricsProviderTest {
         assertThat(res, CoreMatchers.containsString("cc{quantile=\"0.99\",} 10.0"));
     }
 
+    /**
+     * Using TRACE method to visit metrics provider, the response should be 403 forbidden.
+     */
+    @Test
+    public void testTraceCall() throws IOException, IllegalAccessException, NoSuchFieldException {
+        Field privateServerField = provider.getClass().getDeclaredField("server");
+        privateServerField.setAccessible(true);
+        Server server = (Server) privateServerField.get(provider);
+        int port = ((ServerConnector) server.getConnectors()[0]).getLocalPort();
+
+        String metricsUrl = String.format(URL_FORMAT, port);
+        HttpURLConnection conn = (HttpURLConnection) new URL(metricsUrl).openConnection();
+        conn.setRequestMethod("TRACE");
+        conn.connect();
+        Assert.assertEquals(HttpURLConnection.HTTP_FORBIDDEN, conn.getResponseCode());
+    }
+
     @Test
     public void testSummary_asyncAndExceedMaxQueueSize() throws Exception {
         final Properties config = new Properties();