瀏覽代碼

HADOOP-16245. Restrict the effect of LdapGroupsMapping SSL configurations to avoid interfering with other SSL connections. Contributed by Erik Krogen.

Erik Krogen 6 年之前
父節點
當前提交
62efb63006

+ 153 - 13
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/LdapGroupsMapping.java

@@ -17,12 +17,18 @@
  */
 package org.apache.hadoop.security;
 
+import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.io.Reader;
+import java.net.InetAddress;
+import java.net.Socket;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Paths;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Hashtable;
@@ -44,6 +50,13 @@ import javax.naming.directory.SearchResult;
 import javax.naming.ldap.LdapName;
 import javax.naming.ldap.Rdn;
 import javax.naming.spi.InitialContextFactory;
+import javax.net.SocketFactory;
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
 
 import com.google.common.collect.Iterators;
 import org.apache.hadoop.classification.InterfaceAudience;
@@ -273,6 +286,13 @@ public class LdapGroupsMapping
   public static final String LDAP_CTX_FACTORY_CLASS_DEFAULT =
       "com.sun.jndi.ldap.LdapCtxFactory";
 
+  /**
+   * The env key used for specifying a custom socket factory to be used for
+   * creating connections to the LDAP server. This is not a Hadoop conf key.
+   */
+  private static final String LDAP_SOCKET_FACTORY_ENV_KEY =
+      "java.naming.ldap.factory.socket";
+
   private static final Logger LOG =
       LoggerFactory.getLogger(LdapGroupsMapping.class);
 
@@ -640,19 +660,13 @@ public class LdapGroupsMapping
       // Set up SSL security, if necessary
       if (useSsl) {
         env.put(Context.SECURITY_PROTOCOL, "ssl");
-        if (!keystore.isEmpty()) {
-          System.setProperty("javax.net.ssl.keyStore", keystore);
-        }
-        if (!keystorePass.isEmpty()) {
-          System.setProperty("javax.net.ssl.keyStorePassword", keystorePass);
-        }
-        if (!truststore.isEmpty()) {
-          System.setProperty("javax.net.ssl.trustStore", truststore);
-        }
-        if (!truststorePass.isEmpty()) {
-          System.setProperty("javax.net.ssl.trustStorePassword",
-              truststorePass);
-        }
+        // It is necessary to use a custom socket factory rather than setting
+        // system properties to configure these options to avoid interfering
+        // with other SSL factories throughout the system
+        LdapSslSocketFactory.setConfigurations(keystore, keystorePass,
+            truststore, truststorePass);
+        env.put("java.naming.ldap.factory.socket",
+            LdapSslSocketFactory.class.getName());
       }
 
       env.put(Context.SECURITY_PRINCIPAL, currentBindUser.username);
@@ -929,4 +943,130 @@ public class LdapGroupsMapping
       return this.username;
     }
   }
+
+  /**
+   * An private internal socket factory used to create SSL sockets with custom
+   * configuration. There is no way to pass a specific instance of a factory to
+   * the Java naming services, and the instantiated socket factory is not
+   * passed any contextual information, so all information must be encapsulated
+   * directly in the class. Static fields are used here to achieve this. This is
+   * safe since the only usage of {@link LdapGroupsMapping} is within
+   * {@link Groups}, which is a singleton (see the GROUPS field).
+   * <p>
+   * This has nearly the same behavior as an {@link SSLSocketFactory}. The only
+   * additional logic is to configure the key store and trust store.
+   * <p>
+   * This is public only to be accessible by the Java naming services.
+   */
+  @InterfaceAudience.Private
+  public static class LdapSslSocketFactory extends SocketFactory {
+
+    /** Cached value lazy-loaded by {@link #getDefault()}. */
+    private static LdapSslSocketFactory defaultSslFactory;
+
+    private static String keyStoreLocation;
+    private static String keyStorePassword;
+    private static String trustStoreLocation;
+    private static String trustStorePassword;
+
+    private final SSLSocketFactory socketFactory;
+
+    LdapSslSocketFactory(SSLSocketFactory wrappedSocketFactory) {
+      this.socketFactory = wrappedSocketFactory;
+    }
+
+    public static synchronized SocketFactory getDefault() {
+      if (defaultSslFactory == null) {
+        try {
+          SSLContext context = SSLContext.getInstance("TLS");
+          context.init(createKeyManagers(), createTrustManagers(), null);
+          defaultSslFactory =
+              new LdapSslSocketFactory(context.getSocketFactory());
+          LOG.info("Successfully instantiated LdapSslSocketFactory with "
+                  + "keyStoreLocation = {} and trustStoreLocation = {}",
+              keyStoreLocation, trustStoreLocation);
+        } catch (IOException | GeneralSecurityException e) {
+          throw new RuntimeException("Unable to create SSLSocketFactory", e);
+        }
+      }
+      return defaultSslFactory;
+    }
+
+    static synchronized void setConfigurations(String newKeyStoreLocation,
+        String newKeyStorePassword, String newTrustStoreLocation,
+        String newTrustStorePassword) {
+      LdapSslSocketFactory.keyStoreLocation = newKeyStoreLocation;
+      LdapSslSocketFactory.keyStorePassword = newKeyStorePassword;
+      LdapSslSocketFactory.trustStoreLocation = newTrustStoreLocation;
+      LdapSslSocketFactory.trustStorePassword = newTrustStorePassword;
+    }
+
+    private static KeyManager[] createKeyManagers()
+        throws IOException, GeneralSecurityException {
+      if (keyStoreLocation.isEmpty()) {
+        return null;
+      }
+      KeyManagerFactory keyMgrFactory = KeyManagerFactory
+          .getInstance(KeyManagerFactory.getDefaultAlgorithm());
+      keyMgrFactory.init(createKeyStore(keyStoreLocation, keyStorePassword),
+          getPasswordCharArray(keyStorePassword));
+      return keyMgrFactory.getKeyManagers();
+    }
+
+    private static TrustManager[] createTrustManagers()
+        throws IOException, GeneralSecurityException {
+      if (trustStoreLocation.isEmpty()) {
+        return null;
+      }
+      TrustManagerFactory trustMgrFactory = TrustManagerFactory
+          .getInstance(TrustManagerFactory.getDefaultAlgorithm());
+      trustMgrFactory.init(
+          createKeyStore(trustStoreLocation, trustStorePassword));
+      return trustMgrFactory.getTrustManagers();
+    }
+
+    private static KeyStore createKeyStore(String location, String password)
+        throws IOException, GeneralSecurityException {
+      KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
+      try (InputStream keyStoreInput = new FileInputStream(location)) {
+        keyStore.load(keyStoreInput, getPasswordCharArray(password));
+      }
+      return keyStore;
+    }
+
+    private static char[] getPasswordCharArray(String password) {
+      if (password == null || password.isEmpty()) {
+        return null;
+      }
+      return password.toCharArray();
+    }
+
+    @Override
+    public Socket createSocket() throws IOException {
+      return socketFactory.createSocket();
+    }
+
+    @Override
+    public Socket createSocket(String host, int port) throws IOException {
+      return socketFactory.createSocket(host, port);
+    }
+
+    @Override
+    public Socket createSocket(String host, int port, InetAddress localHost,
+        int localPort) throws IOException {
+      return socketFactory.createSocket(host, port, localHost, localPort);
+    }
+
+    @Override
+    public Socket createSocket(InetAddress host, int port) throws IOException {
+      return socketFactory.createSocket(host, port);
+    }
+
+    @Override
+    public Socket createSocket(InetAddress address, int port,
+        InetAddress localAddress, int localPort) throws IOException {
+      return socketFactory.createSocket(address, port, localAddress, localPort);
+    }
+  }
+
 }