|
@@ -22,18 +22,41 @@ import java.io.IOException;
|
|
|
import java.io.PrintWriter;
|
|
|
import java.io.StringWriter;
|
|
|
import java.net.HttpURLConnection;
|
|
|
-import java.util.*;
|
|
|
+import java.util.Set;
|
|
|
+import java.util.HashSet;
|
|
|
+import java.util.Enumeration;
|
|
|
+import java.util.Collection;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
|
|
|
-import javax.servlet.*;
|
|
|
+import javax.servlet.FilterConfig;
|
|
|
+import javax.servlet.FilterChain;
|
|
|
+import javax.servlet.Filter;
|
|
|
+import javax.servlet.ServletContext;
|
|
|
+import javax.servlet.ServletResponse;
|
|
|
+import javax.servlet.ServletRequest;
|
|
|
+import javax.servlet.ServletException;
|
|
|
import javax.servlet.http.Cookie;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
|
|
|
-import static org.junit.Assert.*;
|
|
|
+import static org.junit.Assert.assertTrue;
|
|
|
+import static org.junit.Assert.assertFalse;
|
|
|
+import static org.junit.Assert.assertEquals;
|
|
|
+import static org.junit.Assert.fail;
|
|
|
|
|
|
+import org.apache.hadoop.http.TestHttpServer;
|
|
|
import org.apache.hadoop.yarn.server.webproxy.ProxyUtils;
|
|
|
import org.apache.hadoop.yarn.server.webproxy.WebAppProxyServlet;
|
|
|
+import org.eclipse.jetty.server.Server;
|
|
|
+import org.eclipse.jetty.server.ServerConnector;
|
|
|
+import org.eclipse.jetty.servlet.ServletContextHandler;
|
|
|
+import org.eclipse.jetty.servlet.ServletHolder;
|
|
|
+import org.eclipse.jetty.util.thread.QueuedThreadPool;
|
|
|
import org.glassfish.grizzly.servlet.HttpServletResponseImpl;
|
|
|
import org.junit.Test;
|
|
|
import org.mockito.Mockito;
|
|
@@ -121,6 +144,47 @@ public class TestAmFilter {
|
|
|
filter.destroy();
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void testFindRedirectUrl() throws Exception {
|
|
|
+ final String rm1 = "rm1";
|
|
|
+ final String rm2 = "rm2";
|
|
|
+ // generate a valid URL
|
|
|
+ final String rm1Url = startHttpServer();
|
|
|
+ // invalid url
|
|
|
+ final String rm2Url = "host2:8088";
|
|
|
+
|
|
|
+ TestAmIpFilter filter = new TestAmIpFilter();
|
|
|
+ TestAmIpFilter spy = Mockito.spy(filter);
|
|
|
+ // make sure findRedirectUrl() go to HA branch
|
|
|
+ spy.proxyUriBases = new HashMap<>();
|
|
|
+ spy.proxyUriBases.put(rm1, rm1Url);
|
|
|
+ spy.proxyUriBases.put(rm2, rm2Url);
|
|
|
+
|
|
|
+ Collection<String> rmIds = new ArrayList<>(Arrays.asList(rm1, rm2));
|
|
|
+ Mockito.doReturn(rmIds).when(spy).getRmIds(Mockito.any());
|
|
|
+ Mockito.doReturn(rm1Url).when(spy)
|
|
|
+ .getUrlByRmId(Mockito.any(), Mockito.eq(rm2));
|
|
|
+ Mockito.doReturn(rm2Url).when(spy)
|
|
|
+ .getUrlByRmId(Mockito.any(), Mockito.eq(rm1));
|
|
|
+
|
|
|
+ assertEquals(spy.findRedirectUrl(), rm1Url);
|
|
|
+ }
|
|
|
+
|
|
|
+ private String startHttpServer() throws Exception {
|
|
|
+ Server server = new Server(0);
|
|
|
+ ((QueuedThreadPool)server.getThreadPool()).setMaxThreads(10);
|
|
|
+ ServletContextHandler context = new ServletContextHandler();
|
|
|
+ context.setContextPath("/foo");
|
|
|
+ server.setHandler(context);
|
|
|
+ String servletPath = "/bar";
|
|
|
+ context.addServlet(new ServletHolder(TestHttpServer.EchoServlet.class),
|
|
|
+ servletPath);
|
|
|
+ ((ServerConnector)server.getConnectors()[0]).setHost("localhost");
|
|
|
+ server.start();
|
|
|
+ System.setProperty("sun.net.http.allowRestrictedHeaders", "true");
|
|
|
+ return server.getURI().toString() + servletPath;
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Test AmIpFilter
|
|
|
*/
|