Przeglądaj źródła

HADOOP-14578. Bind IPC connections to kerberos UPN host for proxy users.
Contributed by Daryn Sharp.

Kihwal Lee 7 lat temu
rodzic
commit
e15a08fc9a

+ 21 - 15
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java

@@ -634,7 +634,8 @@ public class Client implements AutoCloseable {
       return false;
     }
     
-    private synchronized void setupConnection() throws IOException {
+    private synchronized void setupConnection(
+        UserGroupInformation ticket) throws IOException {
       short ioFailures = 0;
       short timeoutFailures = 0;
       while (true) {
@@ -662,23 +663,25 @@ public class Client implements AutoCloseable {
            * client, to ensure Server matching address of the client connection
            * to host name in principal passed.
            */
-          UserGroupInformation ticket = remoteId.getTicket();
+          InetSocketAddress bindAddr = null;
           if (ticket != null && ticket.hasKerberosCredentials()) {
             KerberosInfo krbInfo = 
               remoteId.getProtocol().getAnnotation(KerberosInfo.class);
-            if (krbInfo != null && krbInfo.clientPrincipal() != null) {
-              String host = 
-                SecurityUtil.getHostFromPrincipal(remoteId.getTicket().getUserName());
-              
+            if (krbInfo != null) {
+              String principal = ticket.getUserName();
+              String host = SecurityUtil.getHostFromPrincipal(principal);
               // If host name is a valid local address then bind socket to it
               InetAddress localAddr = NetUtils.getLocalInetAddress(host);
               if (localAddr != null) {
-                this.socket.bind(new InetSocketAddress(localAddr, 0));
+                if (LOG.isDebugEnabled()) {
+                  LOG.debug("Binding " + principal + " to " + localAddr);
+                }
+                bindAddr = new InetSocketAddress(localAddr, 0);
               }
             }
           }
-          
-          NetUtils.connect(this.socket, server, connectionTimeout);
+
+          NetUtils.connect(this.socket, server, bindAddr, connectionTimeout);
           this.socket.setSoTimeout(soTimeout);
           return;
         } catch (ConnectTimeoutException toe) {
@@ -762,7 +765,14 @@ public class Client implements AutoCloseable {
         AtomicBoolean fallbackToSimpleAuth) {
       if (socket != null || shouldCloseConnection.get()) {
         return;
-      } 
+      }
+      UserGroupInformation ticket = remoteId.getTicket();
+      if (ticket != null) {
+        final UserGroupInformation realUser = ticket.getRealUser();
+        if (realUser != null) {
+          ticket = realUser;
+        }
+      }
       try {
         if (LOG.isDebugEnabled()) {
           LOG.debug("Connecting to "+server);
@@ -774,14 +784,10 @@ public class Client implements AutoCloseable {
         short numRetries = 0;
         Random rand = null;
         while (true) {
-          setupConnection();
+          setupConnection(ticket);
           ipcStreams = new IpcStreams(socket, maxResponseLength);
           writeConnectionHeader(ipcStreams);
           if (authProtocol == AuthProtocol.SASL) {
-            UserGroupInformation ticket = remoteId.getTicket();
-            if (ticket.getRealUser() != null) {
-              ticket = ticket.getRealUser();
-            }
             try {
               authMethod = ticket
                   .doAs(new PrivilegedExceptionAction<AuthMethod>() {

+ 76 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java

@@ -39,9 +39,12 @@ import java.io.InputStream;
 import java.io.OutputStream;
 import java.lang.reflect.Method;
 import java.lang.reflect.Proxy;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.net.SocketAddress;
+import java.net.SocketException;
 import java.net.SocketTimeoutException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -79,6 +82,7 @@ import org.apache.hadoop.ipc.Server.Connection;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
 import org.apache.hadoop.net.ConnectTimeoutException;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.KerberosInfo;
 import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
@@ -1486,6 +1490,78 @@ public class TestIPC {
     Assert.fail("didn't get limit exceeded");
   }
 
+  @Test
+  public void testUserBinding() throws Exception {
+    checkUserBinding(false);
+  }
+
+  @Test
+  public void testProxyUserBinding() throws Exception {
+    checkUserBinding(true);
+  }
+
+  private void checkUserBinding(boolean asProxy) throws Exception {
+    Socket s;
+    // don't attempt bind with no service host.
+    s = checkConnect(null, asProxy);
+    Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
+
+    // don't attempt bind with service host not belonging to this host.
+    s = checkConnect("1.2.3.4", asProxy);
+    Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
+
+    // do attempt bind when service host is this host.
+    InetAddress addr = InetAddress.getLocalHost();
+    s = checkConnect(addr.getHostAddress(), asProxy);
+    Mockito.verify(s).bind(new InetSocketAddress(addr, 0));
+  }
+
+  // dummy protocol that claims to support kerberos.
+  @KerberosInfo(serverPrincipal = "server@REALM")
+  private static class TestBindingProtocol {
+  }
+
+  private Socket checkConnect(String addr, boolean asProxy) throws Exception {
+    // create a fake ugi that claims to have kerberos credentials.
+    StringBuilder principal = new StringBuilder();
+    principal.append("client");
+    if (addr != null) {
+      principal.append("/").append(addr);
+    }
+    principal.append("@REALM");
+    UserGroupInformation ugi =
+        spy(UserGroupInformation.createRemoteUser(principal.toString()));
+    Mockito.doReturn(true).when(ugi).hasKerberosCredentials();
+    if (asProxy) {
+      ugi = UserGroupInformation.createProxyUser("proxy", ugi);
+    }
+
+    // create a mock socket that throws on connect.
+    SocketException expectedConnectEx =
+        new SocketException("Expected connect failure");
+    Socket s = Mockito.mock(Socket.class);
+    SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
+    Mockito.doReturn(s).when(mockFactory).createSocket();
+    doThrow(expectedConnectEx).when(s).connect(
+        Mockito.any(SocketAddress.class), Mockito.anyInt());
+
+    // do a dummy call and expect it to throw an exception on connect.
+    // tests should verify if/how a bind occurred.
+    try (Client client = new Client(LongWritable.class, conf, mockFactory)) {
+      final InetSocketAddress sockAddr = new InetSocketAddress(0);
+      final LongWritable param = new LongWritable(RANDOM.nextLong());
+      final ConnectionId remoteId = new ConnectionId(
+          sockAddr, TestBindingProtocol.class, ugi, 0,
+          RetryPolicies.TRY_ONCE_THEN_FAIL, conf);
+      client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, null);
+      fail("call didn't throw connect exception");
+    } catch (SocketException se) {
+      // ipc layer re-wraps exceptions, so check the cause.
+      Assert.assertSame(expectedConnectEx, se.getCause());
+    }
+    return s;
+  }
+
   private void doIpcVersionTest(
       byte[] requestData,
       byte[] expectedResponse) throws IOException {