浏览代码

HADOOP-6419. Adds SASL based authentication to RPC. Contributed by Kan Zhang.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@905860 13f79535-47bb-0310-9956-ffa450edef68
Devaraj Das 15 年之前
父节点
当前提交
940389afce

+ 3 - 0
CHANGES.txt

@@ -48,6 +48,9 @@ Trunk (unreleased changes)
     upon login. The tokens are read from a file specified in the
     upon login. The tokens are read from a file specified in the
     environment variable. (ddas)
     environment variable. (ddas)
 
 
+    HADOOP-6419. Adds SASL based authentication to RPC.
+    (Kan Zhang via ddas)
+
   IMPROVEMENTS
   IMPROVEMENTS
 
 
     HADOOP-6283. Improve the exception messages thrown by
     HADOOP-6283. Improve the exception messages thrown by

+ 7 - 2
src/java/org/apache/hadoop/ipc/AvroRpcEngine.java

@@ -33,6 +33,8 @@ import org.apache.commons.logging.*;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.SecretManager;
+import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.net.NetUtils;
 
 
 import org.apache.avro.*;
 import org.apache.avro.*;
@@ -192,10 +194,13 @@ class AvroRpcEngine implements RpcEngine {
    * port and address. */
    * port and address. */
   public RPC.Server getServer(Class iface, Object impl, String bindAddress,
   public RPC.Server getServer(Class iface, Object impl, String bindAddress,
                               int port, int numHandlers, boolean verbose,
                               int port, int numHandlers, boolean verbose,
-                              Configuration conf) throws IOException {
+                              Configuration conf, 
+                       SecretManager<? extends TokenIdentifier> secretManager
+                              ) throws IOException {
     return ENGINE.getServer(TunnelProtocol.class,
     return ENGINE.getServer(TunnelProtocol.class,
                             new TunnelResponder(iface, impl),
                             new TunnelResponder(iface, impl),
-                            bindAddress, port, numHandlers, verbose, conf);
+                            bindAddress, port, numHandlers, verbose, conf, 
+                            secretManager);
   }
   }
 
 
 }
 }

+ 98 - 13
src/java/org/apache/hadoop/ipc/Client.java

@@ -31,7 +31,9 @@ import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
 import java.io.BufferedOutputStream;
 import java.io.FilterInputStream;
 import java.io.FilterInputStream;
 import java.io.InputStream;
 import java.io.InputStream;
+import java.io.OutputStream;
 
 
+import java.security.PrivilegedExceptionAction;
 import java.util.Hashtable;
 import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.Iterator;
 import java.util.Map.Entry;
 import java.util.Map.Entry;
@@ -44,11 +46,19 @@ import org.apache.commons.logging.*;
 
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.IOUtils;
 import org.apache.hadoop.io.IOUtils;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.KerberosInfo;
+import org.apache.hadoop.security.SaslRpcClient;
+import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.security.token.TokenIdentifier;
+import org.apache.hadoop.security.token.TokenSelector;
+import org.apache.hadoop.security.token.TokenInfo;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.ReflectionUtils;
 
 
 /** A client for an IPC service.  IPC calls take a single {@link Writable} as a
 /** A client for an IPC service.  IPC calls take a single {@link Writable} as a
@@ -196,8 +206,13 @@ public class Client {
    * socket: responses may be delivered out of order. */
    * socket: responses may be delivered out of order. */
   private class Connection extends Thread {
   private class Connection extends Thread {
     private InetSocketAddress server;             // server ip:port
     private InetSocketAddress server;             // server ip:port
+    private String serverPrincipal;  // server's krb5 principal name
     private ConnectionHeader header;              // connection header
     private ConnectionHeader header;              // connection header
-    private ConnectionId remoteId;                // connection id
+    private final ConnectionId remoteId;                // connection id
+    private final AuthMethod authMethod; // authentication method
+    private final boolean useSasl;
+    private Token<? extends TokenIdentifier> token;
+    private SaslRpcClient saslRpcClient;
     
     
     private Socket socket = null;                 // connected socket
     private Socket socket = null;                 // connected socket
     private DataInputStream in;
     private DataInputStream in;
@@ -221,6 +236,42 @@ public class Client {
       Class<?> protocol = remoteId.getProtocol();
       Class<?> protocol = remoteId.getProtocol();
       header = 
       header = 
         new ConnectionHeader(protocol == null ? null : protocol.getName(), ticket);
         new ConnectionHeader(protocol == null ? null : protocol.getName(), ticket);
+      this.useSasl = UserGroupInformation.isSecurityEnabled();
+      if (useSasl && protocol != null) {
+        TokenInfo tokenInfo = protocol.getAnnotation(TokenInfo.class);
+        if (tokenInfo != null) {
+          TokenSelector<? extends TokenIdentifier> tokenSelector = null;
+          try {
+            tokenSelector = tokenInfo.value().newInstance();
+          } catch (InstantiationException e) {
+            throw new IOException(e.toString());
+          } catch (IllegalAccessException e) {
+            throw new IOException(e.toString());
+          }
+          InetSocketAddress addr = remoteId.getAddress();
+          token = tokenSelector.selectToken(new Text(addr.getAddress()
+              .getHostAddress() + ":" + addr.getPort()), 
+              ticket.getTokens());
+        }
+        KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
+        if (krbInfo != null) {
+          String serverKey = krbInfo.value();
+          if (serverKey != null) {
+            serverPrincipal = conf.get(serverKey);
+          }
+        }
+      }
+      
+      if (!useSasl) {
+        authMethod = AuthMethod.SIMPLE;
+      } else if (token != null) {
+        authMethod = AuthMethod.DIGEST;
+      } else {
+        authMethod = AuthMethod.KERBEROS;
+      }
+      if (LOG.isDebugEnabled())
+        LOG.debug("Use " + authMethod + " authentication for protocol "
+            + protocol.getSimpleName());
       
       
       this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " +
       this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " +
           remoteId.getAddress().toString() +
           remoteId.getAddress().toString() +
@@ -302,11 +353,20 @@ public class Client {
       }
       }
     }
     }
     
     
+    private synchronized void disposeSasl() {
+      if (saslRpcClient != null) {
+        try {
+          saslRpcClient.dispose();
+        } catch (IOException ignored) {
+        }
+      }
+    }
+    
     /** Connect to the server and set up the I/O streams. It then sends
     /** Connect to the server and set up the I/O streams. It then sends
      * a header to the server and starts
      * a header to the server and starts
      * the connection thread that waits for responses.
      * the connection thread that waits for responses.
      */
      */
-    private synchronized void setupIOstreams() {
+    private synchronized void setupIOstreams() throws InterruptedException {
       if (socket != null || shouldCloseConnection.get()) {
       if (socket != null || shouldCloseConnection.get()) {
         return;
         return;
       }
       }
@@ -334,15 +394,33 @@ public class Client {
             handleConnectionFailure(ioFailures++, maxRetries, ie);
             handleConnectionFailure(ioFailures++, maxRetries, ie);
           }
           }
         }
         }
+        InputStream inStream = NetUtils.getInputStream(socket);
+        OutputStream outStream = NetUtils.getOutputStream(socket);
+        writeRpcHeader(outStream);
+        if (useSasl) {
+          final InputStream in2 = inStream;
+          final OutputStream out2 = outStream;
+          remoteId.getTicket().doAs(new PrivilegedExceptionAction<Object>() {
+            @Override
+            public Object run() throws IOException {
+              saslRpcClient = new SaslRpcClient(authMethod, token,
+                  serverPrincipal);
+              saslRpcClient.saslConnect(in2, out2);
+              return null;
+            }
+          });
+          inStream = saslRpcClient.getInputStream(inStream);
+          outStream = saslRpcClient.getOutputStream(outStream);
+        }
         if (doPing) {
         if (doPing) {
           this.in = new DataInputStream(new BufferedInputStream
           this.in = new DataInputStream(new BufferedInputStream
-            (new PingInputStream(NetUtils.getInputStream(socket))));
+            (new PingInputStream(inStream)));
         } else {
         } else {
           this.in = new DataInputStream(new BufferedInputStream
           this.in = new DataInputStream(new BufferedInputStream
-            (NetUtils.getInputStream(socket)));
+            (inStream));
         }
         }
         this.out = new DataOutputStream
         this.out = new DataOutputStream
-            (new BufferedOutputStream(NetUtils.getOutputStream(socket)));
+            (new BufferedOutputStream(outStream));
         writeHeader();
         writeHeader();
 
 
         // update last activity time
         // update last activity time
@@ -396,14 +474,20 @@ public class Client {
           ". Already tried " + curRetries + " time(s).");
           ". Already tried " + curRetries + " time(s).");
     }
     }
 
 
-    /* Write the header for each connection
+    /* Write the RPC header */
+    private void writeRpcHeader(OutputStream outStream) throws IOException {
+      DataOutputStream out = new DataOutputStream(new BufferedOutputStream(outStream));
+      // Write out the header, version and authentication method
+      out.write(Server.HEADER.array());
+      out.write(Server.CURRENT_VERSION);
+      authMethod.write(out);
+      out.flush();
+    }
+    
+    /* Write the protocol header for each connection
      * Out is not synchronized because only the first thread does this.
      * Out is not synchronized because only the first thread does this.
      */
      */
     private void writeHeader() throws IOException {
     private void writeHeader() throws IOException {
-      // Write out the header and version
-      out.write(Server.HEADER.array());
-      out.write(Server.CURRENT_VERSION);
-
       // Write out the ConnectionHeader
       // Write out the ConnectionHeader
       DataOutputBuffer buf = new DataOutputBuffer();
       DataOutputBuffer buf = new DataOutputBuffer();
       header.write(buf);
       header.write(buf);
@@ -575,6 +659,7 @@ public class Client {
       // close the streams and therefore the socket
       // close the streams and therefore the socket
       IOUtils.closeStream(out);
       IOUtils.closeStream(out);
       IOUtils.closeStream(in);
       IOUtils.closeStream(in);
+      disposeSasl();
 
 
       // clean up all calls
       // clean up all calls
       if (closeException == null) {
       if (closeException == null) {
@@ -815,7 +900,7 @@ public class Client {
    */
    */
   @Deprecated
   @Deprecated
   public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
   public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
-    throws IOException {
+    throws IOException, InterruptedException {
     return call(params, addresses, null, null);
     return call(params, addresses, null, null);
   }
   }
   
   
@@ -825,7 +910,7 @@ public class Client {
    * contains nulls for calls that timed out or errored.  */
    * contains nulls for calls that timed out or errored.  */
   public Writable[] call(Writable[] params, InetSocketAddress[] addresses, 
   public Writable[] call(Writable[] params, InetSocketAddress[] addresses, 
                          Class<?> protocol, UserGroupInformation ticket)
                          Class<?> protocol, UserGroupInformation ticket)
-    throws IOException {
+    throws IOException, InterruptedException {
     if (addresses.length == 0) return new Writable[0];
     if (addresses.length == 0) return new Writable[0];
 
 
     ParallelResults results = new ParallelResults(params.length);
     ParallelResults results = new ParallelResults(params.length);
@@ -859,7 +944,7 @@ public class Client {
                                    Class<?> protocol,
                                    Class<?> protocol,
                                    UserGroupInformation ticket,
                                    UserGroupInformation ticket,
                                    Call call)
                                    Call call)
-                                   throws IOException {
+                                   throws IOException, InterruptedException {
     if (!running.get()) {
     if (!running.get()) {
       // the client is stopped
       // the client is stopped
       throw new IOException("The client is stopped");
       throw new IOException("The client is stopped");

+ 26 - 8
src/java/org/apache/hadoop/ipc/RPC.java

@@ -37,6 +37,8 @@ import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.authorize.AuthorizationException;
 import org.apache.hadoop.security.authorize.AuthorizationException;
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
+import org.apache.hadoop.security.token.SecretManager;
+import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.hadoop.conf.*;
 import org.apache.hadoop.conf.*;
 import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
 import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.ReflectionUtils;
@@ -254,7 +256,7 @@ public class RPC {
   @Deprecated
   @Deprecated
   public static Object[] call(Method method, Object[][] params,
   public static Object[] call(Method method, Object[][] params,
                               InetSocketAddress[] addrs, Configuration conf)
                               InetSocketAddress[] addrs, Configuration conf)
-    throws IOException {
+    throws IOException, InterruptedException {
     return call(method, params, addrs, null, conf);
     return call(method, params, addrs, null, conf);
   }
   }
   
   
@@ -262,7 +264,7 @@ public class RPC {
   public static Object[] call(Method method, Object[][] params,
   public static Object[] call(Method method, Object[][] params,
                               InetSocketAddress[] addrs, 
                               InetSocketAddress[] addrs, 
                               UserGroupInformation ticket, Configuration conf)
                               UserGroupInformation ticket, Configuration conf)
-    throws IOException {
+    throws IOException, InterruptedException {
 
 
     return getProtocolEngine(method.getDeclaringClass(), conf)
     return getProtocolEngine(method.getDeclaringClass(), conf)
       .call(method, params, addrs, ticket, conf);
       .call(method, params, addrs, ticket, conf);
@@ -288,7 +290,7 @@ public class RPC {
                                  final boolean verbose, Configuration conf) 
                                  final boolean verbose, Configuration conf) 
     throws IOException {
     throws IOException {
     return getServer(instance.getClass(),         // use impl class for protocol
     return getServer(instance.getClass(),         // use impl class for protocol
-                     instance, bindAddress, port, numHandlers, false, conf);
+                     instance, bindAddress, port, numHandlers, false, conf, null);
   }
   }
 
 
   /** Construct a server for a protocol implementation instance. */
   /** Construct a server for a protocol implementation instance. */
@@ -296,19 +298,34 @@ public class RPC {
                                  Object instance, String bindAddress,
                                  Object instance, String bindAddress,
                                  int port, Configuration conf) 
                                  int port, Configuration conf) 
     throws IOException {
     throws IOException {
-    return getServer(protocol, instance, bindAddress, port, 1, false, conf);
+    return getServer(protocol, instance, bindAddress, port, 1, false, conf, null);
   }
   }
 
 
-  /** Construct a server for a protocol implementation instance. */
+  /** Construct a server for a protocol implementation instance.
+   * @deprecated secretManager should be passed.
+   */
+  @Deprecated
   public static Server getServer(Class protocol,
   public static Server getServer(Class protocol,
                                  Object instance, String bindAddress, int port,
                                  Object instance, String bindAddress, int port,
                                  int numHandlers,
                                  int numHandlers,
                                  boolean verbose, Configuration conf) 
                                  boolean verbose, Configuration conf) 
     throws IOException {
     throws IOException {
     
     
+    return getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
+                 conf, null);
+  }
+  
+  /** Construct a server for a protocol implementation instance. */
+  public static Server getServer(Class<?> protocol,
+                                 Object instance, String bindAddress, int port,
+                                 int numHandlers,
+                                 boolean verbose, Configuration conf,
+                                 SecretManager<? extends TokenIdentifier> secretManager) 
+    throws IOException {
+    
     return getProtocolEngine(protocol, conf)
     return getProtocolEngine(protocol, conf)
       .getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
       .getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
-                 conf);
+                 conf, secretManager);
   }
   }
 
 
   /** An RPC Server. */
   /** An RPC Server. */
@@ -316,8 +333,9 @@ public class RPC {
   
   
     protected Server(String bindAddress, int port, 
     protected Server(String bindAddress, int port, 
                      Class<? extends Writable> paramClass, int handlerCount, 
                      Class<? extends Writable> paramClass, int handlerCount, 
-                     Configuration conf, String serverName) throws IOException {
-      super(bindAddress, port, paramClass, handlerCount, conf, serverName);
+                     Configuration conf, String serverName, 
+                     SecretManager<? extends TokenIdentifier> secretManager) throws IOException {
+      super(bindAddress, port, paramClass, handlerCount, conf, serverName, secretManager);
     }
     }
   }
   }
 
 

+ 6 - 2
src/java/org/apache/hadoop/ipc/RpcEngine.java

@@ -24,6 +24,8 @@ import java.net.InetSocketAddress;
 import javax.net.SocketFactory;
 import javax.net.SocketFactory;
 
 
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.SecretManager;
+import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 
 
 /** An RPC implementation. */
 /** An RPC implementation. */
@@ -41,11 +43,13 @@ interface RpcEngine {
   /** Expert: Make multiple, parallel calls to a set of servers. */
   /** Expert: Make multiple, parallel calls to a set of servers. */
   Object[] call(Method method, Object[][] params, InetSocketAddress[] addrs,
   Object[] call(Method method, Object[][] params, InetSocketAddress[] addrs,
                 UserGroupInformation ticket, Configuration conf)
                 UserGroupInformation ticket, Configuration conf)
-    throws IOException;
+    throws IOException, InterruptedException;
 
 
   /** Construct a server for a protocol implementation instance. */
   /** Construct a server for a protocol implementation instance. */
   RPC.Server getServer(Class protocol, Object instance, String bindAddress,
   RPC.Server getServer(Class protocol, Object instance, String bindAddress,
                        int port, int numHandlers, boolean verbose,
                        int port, int numHandlers, boolean verbose,
-                       Configuration conf) throws IOException;
+                       Configuration conf, 
+                       SecretManager<? extends TokenIdentifier> secretManager
+                       ) throws IOException;
 
 
 }
 }

+ 258 - 46
src/java/org/apache/hadoop/ipc/Server.java

@@ -32,6 +32,7 @@ import java.net.SocketException;
 import java.net.UnknownHostException;
 import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
 import java.nio.channels.CancelledKeyException;
 import java.nio.channels.CancelledKeyException;
+import java.nio.channels.Channels;
 import java.nio.channels.ClosedChannelException;
 import java.nio.channels.ClosedChannelException;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SelectionKey;
@@ -39,7 +40,6 @@ import java.nio.channels.Selector;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
 import java.nio.channels.SocketChannel;
 import java.nio.channels.WritableByteChannel;
 import java.nio.channels.WritableByteChannel;
-import java.security.PrivilegedActionException;
 import java.security.PrivilegedExceptionAction;
 import java.security.PrivilegedExceptionAction;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Collections;
@@ -52,15 +52,26 @@ import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 
 
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.ipc.metrics.RpcMetrics;
 import org.apache.hadoop.ipc.metrics.RpcMetrics;
+import org.apache.hadoop.security.AccessControlException;
+import org.apache.hadoop.security.SaslRpcServer;
+import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
+import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler;
+import org.apache.hadoop.security.SaslRpcServer.SaslGssCallbackHandler;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.authorize.AuthorizationException;
 import org.apache.hadoop.security.authorize.AuthorizationException;
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
+import org.apache.hadoop.security.token.TokenIdentifier;
+import org.apache.hadoop.security.token.SecretManager;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.util.StringUtils;
 
 
@@ -80,7 +91,8 @@ public abstract class Server {
   
   
   // 1 : Introduce ping and server does not throw away RPCs
   // 1 : Introduce ping and server does not throw away RPCs
   // 3 : Introduce the protocol into the RPC connection header
   // 3 : Introduce the protocol into the RPC connection header
-  public static final byte CURRENT_VERSION = 3;
+  // 4 : Introduced SASL security layer
+  public static final byte CURRENT_VERSION = 4;
   
   
   /**
   /**
    * How many calls/handler are allowed in the queue.
    * How many calls/handler are allowed in the queue.
@@ -158,6 +170,7 @@ public abstract class Server {
   protected RpcMetrics  rpcMetrics;
   protected RpcMetrics  rpcMetrics;
   
   
   private Configuration conf;
   private Configuration conf;
+  private SecretManager<TokenIdentifier> secretManager;
 
 
   private int maxQueueSize;
   private int maxQueueSize;
   private int socketSendBufferSize;
   private int socketSendBufferSize;
@@ -431,7 +444,7 @@ public abstract class Server {
       if (count < 0) {
       if (count < 0) {
         if (LOG.isDebugEnabled())
         if (LOG.isDebugEnabled())
           LOG.debug(getName() + ": disconnecting client " + 
           LOG.debug(getName() + ": disconnecting client " + 
-                    c.getHostAddress() + ". Number of active connections: "+
+                    c + ". Number of active connections: "+
                     numConnections);
                     numConnections);
         closeConnection(c);
         closeConnection(c);
         c = null;
         c = null;
@@ -703,8 +716,7 @@ public abstract class Server {
 
 
   /** Reads calls from a connection and queues them for handling. */
   /** Reads calls from a connection and queues them for handling. */
   private class Connection {
   private class Connection {
-    private boolean versionRead = false; //if initial signature and
-                                         //version are read
+    private boolean rpcHeaderRead = false; // if initial rpc header is read
     private boolean headerRead = false;  //if the connection header that
     private boolean headerRead = false;  //if the connection header that
                                          //follows version is read.
                                          //follows version is read.
 
 
@@ -723,6 +735,13 @@ public abstract class Server {
     
     
     ConnectionHeader header = new ConnectionHeader();
     ConnectionHeader header = new ConnectionHeader();
     Class<?> protocol;
     Class<?> protocol;
+    boolean useSasl;
+    SaslServer saslServer;
+    private AuthMethod authMethod;
+    private boolean saslContextEstablished;
+    private ByteBuffer rpcHeaderBuffer;
+    private ByteBuffer unwrappedData;
+    private ByteBuffer unwrappedDataLengthBuffer;
     
     
     UserGroupInformation user = null;
     UserGroupInformation user = null;
 
 
@@ -731,6 +750,10 @@ public abstract class Server {
     private final Call authFailedCall = 
     private final Call authFailedCall = 
       new Call(AUTHROIZATION_FAILED_CALLID, null, null);
       new Call(AUTHROIZATION_FAILED_CALLID, null, null);
     private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
     private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
+    // Fake 'call' for SASL context setup
+    private static final int SASL_CALLID = -33;
+    private final Call saslCall = new Call(SASL_CALLID, null, null);
+    private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream();
     
     
     public Connection(SelectionKey key, SocketChannel channel, 
     public Connection(SelectionKey key, SocketChannel channel, 
                       long lastContact) {
                       long lastContact) {
@@ -738,6 +761,8 @@ public abstract class Server {
       this.lastContact = lastContact;
       this.lastContact = lastContact;
       this.data = null;
       this.data = null;
       this.dataLengthBuffer = ByteBuffer.allocate(4);
       this.dataLengthBuffer = ByteBuffer.allocate(4);
+      this.unwrappedData = null;
+      this.unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
       this.socket = channel.socket();
       this.socket = channel.socket();
       InetAddress addr = socket.getInetAddress();
       InetAddress addr = socket.getInetAddress();
       if (addr == null) {
       if (addr == null) {
@@ -795,6 +820,92 @@ public abstract class Server {
       return false;
       return false;
     }
     }
 
 
+    private void saslReadAndProcess(byte[] saslToken) throws IOException,
+        InterruptedException {
+      if (!saslContextEstablished) {
+        if (saslServer == null) {
+          switch (authMethod) {
+          case DIGEST:
+            saslServer = Sasl.createSaslServer(AuthMethod.DIGEST
+                .getMechanismName(), null, SaslRpcServer.SASL_DEFAULT_REALM,
+                SaslRpcServer.SASL_PROPS, new SaslDigestCallbackHandler(
+                    secretManager));
+            break;
+          default:
+            UserGroupInformation current = UserGroupInformation
+                .getCurrentUser();
+            String fullName = current.getUserName();
+            if (LOG.isDebugEnabled())
+              LOG.debug("Kerberos principal name is " + fullName);
+            final String names[] = SaslRpcServer.splitKerberosName(fullName);
+            if (names.length != 3) {
+              throw new IOException(
+                  "Kerberos principal name does NOT have the expected "
+                      + "hostname part: " + fullName);
+            }
+            current.doAs(new PrivilegedExceptionAction<Object>() {
+              @Override
+              public Object run() throws IOException {
+                saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS
+                    .getMechanismName(), names[0], names[1],
+                    SaslRpcServer.SASL_PROPS, new SaslGssCallbackHandler());
+                return null;
+              }
+            });
+          }
+          if (saslServer == null)
+            throw new IOException(
+                "Unable to find SASL server implementation for "
+                    + authMethod.getMechanismName());
+          if (LOG.isDebugEnabled())
+            LOG.debug("Created SASL server with mechanism = "
+                + authMethod.getMechanismName());
+        }
+        if (LOG.isDebugEnabled())
+          LOG.debug("Have read input token of size " + saslToken.length
+              + " for processing by saslServer.evaluateResponse()");
+        byte[] replyToken = saslServer.evaluateResponse(saslToken);
+        if (replyToken != null) {
+          if (LOG.isDebugEnabled())
+            LOG.debug("Will send token of size " + replyToken.length
+                + " from saslServer.");
+          saslCall.connection = this;
+          saslResponse.reset();
+          DataOutputStream out = new DataOutputStream(saslResponse);
+          out.writeInt(replyToken.length);
+          out.write(replyToken, 0, replyToken.length);
+          saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray()));
+          responder.doRespond(saslCall);
+        }
+        if (saslServer.isComplete()) {
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("SASL server context established. Negotiated QoP is "
+                + saslServer.getNegotiatedProperty(Sasl.QOP));
+          }
+          user = UserGroupInformation.createRemoteUser(saslServer
+              .getAuthorizationID());
+          LOG.info("SASL server successfully authenticated client: " + user);
+          saslContextEstablished = true;
+        }
+      } else {
+        if (LOG.isDebugEnabled())
+          LOG.debug("Have read input token of size " + saslToken.length
+              + " for processing by saslServer.unwrap()");
+        byte[] plaintextData = saslServer
+            .unwrap(saslToken, 0, saslToken.length);
+        processUnwrappedData(plaintextData);
+      }
+    }
+    
+    private void disposeSasl() {
+      if (saslServer != null) {
+        try {
+          saslServer.dispose();
+        } catch (SaslException ignored) {
+        }
+      }
+    }
+    
     public int readAndProcess() throws IOException, InterruptedException {
     public int readAndProcess() throws IOException, InterruptedException {
       while (true) {
       while (true) {
         /* Read at most one RPC. If the header is not read completely yet
         /* Read at most one RPC. If the header is not read completely yet
@@ -807,14 +918,33 @@ public abstract class Server {
             return count;
             return count;
         }
         }
       
       
-        if (!versionRead) {
+        if (!rpcHeaderRead) {
           //Every connection is expected to send the header.
           //Every connection is expected to send the header.
-          ByteBuffer versionBuffer = ByteBuffer.allocate(1);
-          count = channelRead(channel, versionBuffer);
-          if (count <= 0) {
+          if (rpcHeaderBuffer == null) {
+            rpcHeaderBuffer = ByteBuffer.allocate(2);
+          }
+          count = channelRead(channel, rpcHeaderBuffer);
+          if (count < 0 || rpcHeaderBuffer.remaining() > 0) {
             return count;
             return count;
           }
           }
-          int version = versionBuffer.get(0);
+          int version = rpcHeaderBuffer.get(0);
+          byte[] method = new byte[] {rpcHeaderBuffer.get(1)};
+          authMethod = AuthMethod.read(new DataInputStream(
+              new ByteArrayInputStream(method)));
+          if (authMethod == null) {
+            throw new IOException("Unable to read authentication method");
+          }
+          if (UserGroupInformation.isSecurityEnabled()
+              && authMethod == AuthMethod.SIMPLE) {
+            throw new IOException("Authentication is required");
+          } 
+          if (!UserGroupInformation.isSecurityEnabled()
+              && authMethod != AuthMethod.SIMPLE) {
+            throw new IOException("Authentication is not supported");
+          }
+          if (authMethod != AuthMethod.SIMPLE) {
+            useSasl = true;
+          }
           
           
           dataLengthBuffer.flip();          
           dataLengthBuffer.flip();          
           if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) {
           if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) {
@@ -826,7 +956,8 @@ public abstract class Server {
             return -1;
             return -1;
           }
           }
           dataLengthBuffer.clear();
           dataLengthBuffer.clear();
-          versionRead = true;
+          rpcHeaderBuffer = null;
+          rpcHeaderRead = true;
           continue;
           continue;
         }
         }
         
         
@@ -834,12 +965,11 @@ public abstract class Server {
           dataLengthBuffer.flip();
           dataLengthBuffer.flip();
           dataLength = dataLengthBuffer.getInt();
           dataLength = dataLengthBuffer.getInt();
        
        
-          if (dataLength == Client.PING_CALL_ID) {
+          if (!useSasl && dataLength == Client.PING_CALL_ID) {
             dataLengthBuffer.clear();
             dataLengthBuffer.clear();
             return 0;  //ping message
             return 0;  //ping message
           }
           }
           data = ByteBuffer.allocate(dataLength);
           data = ByteBuffer.allocate(dataLength);
-          incRpcCount();  // Increment the rpc count
         }
         }
         
         
         count = channelRead(channel, data);
         count = channelRead(channel, data);
@@ -847,33 +977,14 @@ public abstract class Server {
         if (data.remaining() == 0) {
         if (data.remaining() == 0) {
           dataLengthBuffer.clear();
           dataLengthBuffer.clear();
           data.flip();
           data.flip();
-          if (headerRead) {
-            processData();
-            data = null;
-            return count;
+          boolean isHeaderRead = headerRead;
+          if (useSasl) {
+            saslReadAndProcess(data.array());
           } else {
           } else {
-            processHeader();
-            headerRead = true;
-            data = null;
-            
-            // Authorize the connection
-            try {
-              authorize(user, header);
-              
-              if (LOG.isDebugEnabled()) {
-                LOG.debug("Successfully authorized " + header);
-              }
-            } catch (AuthorizationException ae) {
-              authFailedCall.connection = this;
-              setupResponse(authFailedResponse, authFailedCall, 
-                            Status.FATAL, null, 
-                            ae.getClass().getName(), ae.getMessage());
-              responder.doRespond(authFailedCall);
-              
-              // Close this connection
-              return -1;
-            }
-
+            processOneRpc(data.array());
+          }
+          data = null;
+          if (!isHeaderRead) {
             continue;
             continue;
           }
           }
         } 
         } 
@@ -882,9 +993,9 @@ public abstract class Server {
     }
     }
 
 
     /// Reads the connection header following version
     /// Reads the connection header following version
-    private void processHeader() throws IOException {
+    private void processHeader(byte[] buf) throws IOException {
       DataInputStream in =
       DataInputStream in =
-        new DataInputStream(new ByteArrayInputStream(data.array()));
+        new DataInputStream(new ByteArrayInputStream(buf));
       header.readFields(in);
       header.readFields(in);
       try {
       try {
         String protocolClassName = header.getProtocol();
         String protocolClassName = header.getProtocol();
@@ -895,12 +1006,73 @@ public abstract class Server {
         throw new IOException("Unknown protocol: " + header.getProtocol());
         throw new IOException("Unknown protocol: " + header.getProtocol());
       }
       }
       
       
-      user = header.getUgi();
+      UserGroupInformation protocolUser = header.getUgi();
+      if (!useSasl) {
+        user = protocolUser;
+      } else if (protocolUser != null && !protocolUser.equals(user)) {
+        throw new AccessControlException("Authenticated user (" + user
+            + ") doesn't match what the client claims to be (" + protocolUser
+            + ")");
+      }
+    }
+    
+    private void processUnwrappedData(byte[] inBuf) throws IOException,
+        InterruptedException {
+      ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
+          inBuf));
+      // Read all RPCs contained in the inBuf, even partial ones
+      while (true) {
+        int count = -1;
+        if (unwrappedDataLengthBuffer.remaining() > 0) {
+          count = channelRead(ch, unwrappedDataLengthBuffer);
+          if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0)
+            return;
+        }
+
+        if (unwrappedData == null) {
+          unwrappedDataLengthBuffer.flip();
+          int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
+
+          if (unwrappedDataLength == Client.PING_CALL_ID) {
+            if (LOG.isDebugEnabled())
+              LOG.debug("Received ping message");
+            unwrappedDataLengthBuffer.clear();
+            continue; // ping message
+          }
+          unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
+        }
+
+        count = channelRead(ch, unwrappedData);
+        if (count <= 0 || unwrappedData.remaining() > 0)
+          return;
+
+        if (unwrappedData.remaining() == 0) {
+          unwrappedDataLengthBuffer.clear();
+          unwrappedData.flip();
+          processOneRpc(unwrappedData.array());
+          unwrappedData = null;
+        }
+      }
     }
     }
     
     
-    private void processData() throws  IOException, InterruptedException {
+    private void processOneRpc(byte[] buf) throws IOException,
+        InterruptedException {
+      if (headerRead) {
+        processData(buf);
+      } else {
+        processHeader(buf);
+        headerRead = true;
+        if (!authorizeConnection()) {
+          throw new AccessControlException("Connection from " + this
+              + " for protocol " + header.getProtocol()
+              + " is unauthorized for user " + user);
+        }
+      }
+    }
+    
+    private void processData(byte[] buf) throws  IOException, InterruptedException {
       DataInputStream dis =
       DataInputStream dis =
-        new DataInputStream(new ByteArrayInputStream(data.array()));
+        new DataInputStream(new ByteArrayInputStream(buf));
       int id = dis.readInt();                    // try to read an id
       int id = dis.readInt();                    // try to read an id
         
         
       if (LOG.isDebugEnabled())
       if (LOG.isDebugEnabled())
@@ -911,9 +1083,27 @@ public abstract class Server {
         
         
       Call call = new Call(id, param, this);
       Call call = new Call(id, param, this);
       callQueue.put(call);              // queue the call; maybe blocked here
       callQueue.put(call);              // queue the call; maybe blocked here
+      incRpcCount();  // Increment the rpc count
     }
     }
 
 
+    private boolean authorizeConnection() throws IOException {
+      try {
+        authorize(user, header);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Successfully authorized " + header);
+        }
+      } catch (AuthorizationException ae) {
+        authFailedCall.connection = this;
+        setupResponse(authFailedResponse, authFailedCall, Status.FATAL, null,
+            ae.getClass().getName(), ae.getMessage());
+        responder.doRespond(authFailedCall);
+        return false;
+      }
+      return true;
+    }
+    
     private synchronized void close() throws IOException {
     private synchronized void close() throws IOException {
+      disposeSasl();
       data = null;
       data = null;
       dataLengthBuffer = null;
       dataLengthBuffer = null;
       if (!channel.isOpen())
       if (!channel.isOpen())
@@ -1011,16 +1201,17 @@ public abstract class Server {
                   Configuration conf)
                   Configuration conf)
     throws IOException 
     throws IOException 
   {
   {
-    this(bindAddress, port, paramClass, handlerCount,  conf, Integer.toString(port));
+    this(bindAddress, port, paramClass, handlerCount,  conf, Integer.toString(port), null);
   }
   }
   /** Constructs a server listening on the named port and address.  Parameters passed must
   /** Constructs a server listening on the named port and address.  Parameters passed must
    * be of the named class.  The <code>handlerCount</handlerCount> determines
    * be of the named class.  The <code>handlerCount</handlerCount> determines
    * the number of handler threads that will be used to process calls.
    * the number of handler threads that will be used to process calls.
    * 
    * 
    */
    */
+  @SuppressWarnings("unchecked")
   protected Server(String bindAddress, int port, 
   protected Server(String bindAddress, int port, 
                   Class<? extends Writable> paramClass, int handlerCount, 
                   Class<? extends Writable> paramClass, int handlerCount, 
-                  Configuration conf, String serverName) 
+                  Configuration conf, String serverName, SecretManager<? extends TokenIdentifier> secretManager) 
     throws IOException {
     throws IOException {
     this.bindAddress = bindAddress;
     this.bindAddress = bindAddress;
     this.conf = conf;
     this.conf = conf;
@@ -1033,6 +1224,7 @@ public abstract class Server {
     this.maxIdleTime = 2*conf.getInt("ipc.client.connection.maxidletime", 1000);
     this.maxIdleTime = 2*conf.getInt("ipc.client.connection.maxidletime", 1000);
     this.maxConnectionsToNuke = conf.getInt("ipc.client.kill.max", 10);
     this.maxConnectionsToNuke = conf.getInt("ipc.client.kill.max", 10);
     this.thresholdIdleConnections = conf.getInt("ipc.client.idlethreshold", 4000);
     this.thresholdIdleConnections = conf.getInt("ipc.client.idlethreshold", 4000);
+    this.secretManager = (SecretManager<TokenIdentifier>) secretManager;
     this.authorize = 
     this.authorize = 
       conf.getBoolean(ServiceAuthorizationManager.SERVICE_AUTHORIZATION_CONFIG, 
       conf.getBoolean(ServiceAuthorizationManager.SERVICE_AUTHORIZATION_CONFIG, 
                       false);
                       false);
@@ -1086,9 +1278,29 @@ public abstract class Server {
       WritableUtils.writeString(out, errorClass);
       WritableUtils.writeString(out, errorClass);
       WritableUtils.writeString(out, error);
       WritableUtils.writeString(out, error);
     }
     }
+    wrapWithSasl(response, call);
     call.setResponse(ByteBuffer.wrap(response.toByteArray()));
     call.setResponse(ByteBuffer.wrap(response.toByteArray()));
   }
   }
   
   
+  private void wrapWithSasl(ByteArrayOutputStream response, Call call)
+      throws IOException {
+    if (call.connection.useSasl) {
+      byte[] token = response.toByteArray();
+      // synchronization may be needed since there can be multiple Handler
+      // threads using saslServer to wrap responses.
+      synchronized (call.connection.saslServer) {
+        token = call.connection.saslServer.wrap(token, 0, token.length);
+      }
+      if (LOG.isDebugEnabled())
+        LOG.debug("Adding saslServer wrapped token of size " + token.length
+            + " as call response.");
+      response.reset();
+      DataOutputStream saslOut = new DataOutputStream(response);
+      saslOut.writeInt(token.length);
+      saslOut.write(token, 0, token.length);
+    }
+  }
+  
   Configuration getConf() {
   Configuration getConf() {
     return conf;
     return conf;
   }
   }

+ 13 - 6
src/java/org/apache/hadoop/ipc/WritableRpcEngine.java

@@ -36,6 +36,8 @@ import org.apache.commons.logging.*;
 import org.apache.hadoop.io.*;
 import org.apache.hadoop.io.*;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
+import org.apache.hadoop.security.token.SecretManager;
+import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.hadoop.conf.*;
 import org.apache.hadoop.conf.*;
 import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
 import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
 
 
@@ -246,7 +248,7 @@ class WritableRpcEngine implements RpcEngine {
   public Object[] call(Method method, Object[][] params,
   public Object[] call(Method method, Object[][] params,
                        InetSocketAddress[] addrs, 
                        InetSocketAddress[] addrs, 
                        UserGroupInformation ticket, Configuration conf)
                        UserGroupInformation ticket, Configuration conf)
-    throws IOException {
+    throws IOException, InterruptedException {
 
 
     Invocation[] invocations = new Invocation[params.length];
     Invocation[] invocations = new Invocation[params.length];
     for (int i = 0; i < params.length; i++)
     for (int i = 0; i < params.length; i++)
@@ -276,9 +278,11 @@ class WritableRpcEngine implements RpcEngine {
    * port and address. */
    * port and address. */
   public Server getServer(Class protocol,
   public Server getServer(Class protocol,
                           Object instance, String bindAddress, int port,
                           Object instance, String bindAddress, int port,
-                          int numHandlers, boolean verbose, Configuration conf) 
+                          int numHandlers, boolean verbose, Configuration conf,
+                      SecretManager<? extends TokenIdentifier> secretManager) 
     throws IOException {
     throws IOException {
-    return new Server(instance, conf, bindAddress, port, numHandlers, verbose);
+    return new Server(instance, conf, bindAddress, port, numHandlers, 
+        verbose, secretManager);
   }
   }
 
 
   /** An RPC Server. */
   /** An RPC Server. */
@@ -294,7 +298,7 @@ class WritableRpcEngine implements RpcEngine {
      */
      */
     public Server(Object instance, Configuration conf, String bindAddress, int port) 
     public Server(Object instance, Configuration conf, String bindAddress, int port) 
       throws IOException {
       throws IOException {
-      this(instance, conf,  bindAddress, port, 1, false);
+      this(instance, conf,  bindAddress, port, 1, false, null);
     }
     }
     
     
     private static String classNameBase(String className) {
     private static String classNameBase(String className) {
@@ -314,8 +318,11 @@ class WritableRpcEngine implements RpcEngine {
      * @param verbose whether each call should be logged
      * @param verbose whether each call should be logged
      */
      */
     public Server(Object instance, Configuration conf, String bindAddress,  int port,
     public Server(Object instance, Configuration conf, String bindAddress,  int port,
-                  int numHandlers, boolean verbose) throws IOException {
-      super(bindAddress, port, Invocation.class, numHandlers, conf, classNameBase(instance.getClass().getName()));
+                  int numHandlers, boolean verbose, 
+                  SecretManager<? extends TokenIdentifier> secretManager) 
+        throws IOException {
+      super(bindAddress, port, Invocation.class, numHandlers, conf, 
+          classNameBase(instance.getClass().getName()), secretManager);
       this.instance = instance;
       this.instance = instance;
       this.verbose = verbose;
       this.verbose = verbose;
     }
     }

+ 31 - 0
src/java/org/apache/hadoop/security/KerberosInfo.java

@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security;
+
+import java.lang.annotation.*;
+
+/**
+ * Indicates Kerberos related information to be used
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.TYPE)
+public @interface KerberosInfo {
+  /** Key for getting server's Kerberos principal name from Configuration */
+  String value();
+}

+ 321 - 0
src/java/org/apache/hadoop/security/SaslInputStream.java

@@ -0,0 +1,321 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security;
+
+import java.io.DataInputStream;
+import java.io.EOFException;
+import java.io.InputStream;
+import java.io.IOException;
+
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * A SaslInputStream is composed of an InputStream and a SaslServer (or
+ * SaslClient) so that read() methods return data that are read in from the
+ * underlying InputStream but have been additionally processed by the SaslServer
+ * (or SaslClient) object. The SaslServer (or SaslClient) object must be fully
+ * initialized before being used by a SaslInputStream.
+ */
+public class SaslInputStream extends InputStream {
+  public static final Log LOG = LogFactory.getLog(SaslInputStream.class);
+
+  private final DataInputStream inStream;
+  /*
+   * data read from the underlying input stream before being processed by SASL
+   */
+  private byte[] saslToken;
+  private final SaslClient saslClient;
+  private final SaslServer saslServer;
+  private byte[] lengthBuf = new byte[4];
+  /*
+   * buffer holding data that have been processed by SASL, but have not been
+   * read out
+   */
+  private byte[] obuffer;
+  // position of the next "new" byte
+  private int ostart = 0;
+  // position of the last "new" byte
+  private int ofinish = 0;
+
+  private static int unsignedBytesToInt(byte[] buf) {
+    if (buf.length != 4) {
+      throw new IllegalArgumentException(
+          "Cannot handle byte array other than 4 bytes");
+    }
+    int result = 0;
+    for (int i = 0; i < 4; i++) {
+      result <<= 8;
+      result |= ((int) buf[i] & 0xff);
+    }
+    return result;
+  }
+
+  /**
+   * Read more data and get them processed <br>
+   * Entry condition: ostart = ofinish <br>
+   * Exit condition: ostart <= ofinish <br>
+   * 
+   * return (ofinish-ostart) (we have this many bytes for you), 0 (no data now,
+   * but could have more later), or -1 (absolutely no more data)
+   */
+  private int readMoreData() throws IOException {
+    try {
+      inStream.readFully(lengthBuf);
+      int length = unsignedBytesToInt(lengthBuf);
+      if (LOG.isDebugEnabled())
+        LOG.debug("Actual length is " + length);
+      saslToken = new byte[length];
+      inStream.readFully(saslToken);
+    } catch (EOFException e) {
+      return -1;
+    }
+    try {
+      if (saslServer != null) { // using saslServer
+        obuffer = saslServer.unwrap(saslToken, 0, saslToken.length);
+      } else { // using saslClient
+        obuffer = saslClient.unwrap(saslToken, 0, saslToken.length);
+      }
+    } catch (SaslException se) {
+      try {
+        disposeSasl();
+      } catch (SaslException ignored) {
+      }
+      throw se;
+    }
+    ostart = 0;
+    if (obuffer == null)
+      ofinish = 0;
+    else
+      ofinish = obuffer.length;
+    return ofinish;
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information Sasl
+   * might be using.
+   * 
+   * @exception SaslException
+   *              if a SASL error occurs.
+   */
+  private void disposeSasl() throws SaslException {
+    if (saslClient != null) {
+      saslClient.dispose();
+    }
+    if (saslServer != null) {
+      saslServer.dispose();
+    }
+  }
+
+  /**
+   * Constructs a SASLInputStream from an InputStream and a SaslServer <br>
+   * Note: if the specified InputStream or SaslServer is null, a
+   * NullPointerException may be thrown later when they are used.
+   * 
+   * @param inStream
+   *          the InputStream to be processed
+   * @param saslServer
+   *          an initialized SaslServer object
+   */
+  public SaslInputStream(InputStream inStream, SaslServer saslServer) {
+    this.inStream = new DataInputStream(inStream);
+    this.saslServer = saslServer;
+    this.saslClient = null;
+  }
+
+  /**
+   * Constructs a SASLInputStream from an InputStream and a SaslClient <br>
+   * Note: if the specified InputStream or SaslClient is null, a
+   * NullPointerException may be thrown later when they are used.
+   * 
+   * @param inStream
+   *          the InputStream to be processed
+   * @param saslClient
+   *          an initialized SaslClient object
+   */
+  public SaslInputStream(InputStream inStream, SaslClient saslClient) {
+    this.inStream = new DataInputStream(inStream);
+    this.saslServer = null;
+    this.saslClient = saslClient;
+  }
+
+  /**
+   * Reads the next byte of data from this input stream. The value byte is
+   * returned as an <code>int</code> in the range <code>0</code> to
+   * <code>255</code>. If no byte is available because the end of the stream has
+   * been reached, the value <code>-1</code> is returned. This method blocks
+   * until input data is available, the end of the stream is detected, or an
+   * exception is thrown.
+   * <p>
+   * 
+   * @return the next byte of data, or <code>-1</code> if the end of the stream
+   *         is reached.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public int read() throws IOException {
+    if (ostart >= ofinish) {
+      // we loop for new data as we are blocking
+      int i = 0;
+      while (i == 0)
+        i = readMoreData();
+      if (i == -1)
+        return -1;
+    }
+    return ((int) obuffer[ostart++] & 0xff);
+  }
+
+  /**
+   * Reads up to <code>b.length</code> bytes of data from this input stream into
+   * an array of bytes.
+   * <p>
+   * The <code>read</code> method of <code>InputStream</code> calls the
+   * <code>read</code> method of three arguments with the arguments
+   * <code>b</code>, <code>0</code>, and <code>b.length</code>.
+   * 
+   * @param b
+   *          the buffer into which the data is read.
+   * @return the total number of bytes read into the buffer, or <code>-1</code>
+   *         is there is no more data because the end of the stream has been
+   *         reached.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public int read(byte[] b) throws IOException {
+    return read(b, 0, b.length);
+  }
+
+  /**
+   * Reads up to <code>len</code> bytes of data from this input stream into an
+   * array of bytes. This method blocks until some input is available. If the
+   * first argument is <code>null,</code> up to <code>len</code> bytes are read
+   * and discarded.
+   * 
+   * @param b
+   *          the buffer into which the data is read.
+   * @param off
+   *          the start offset of the data.
+   * @param len
+   *          the maximum number of bytes read.
+   * @return the total number of bytes read into the buffer, or <code>-1</code>
+   *         if there is no more data because the end of the stream has been
+   *         reached.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public int read(byte[] b, int off, int len) throws IOException {
+    if (ostart >= ofinish) {
+      // we loop for new data as we are blocking
+      int i = 0;
+      while (i == 0)
+        i = readMoreData();
+      if (i == -1)
+        return -1;
+    }
+    if (len <= 0) {
+      return 0;
+    }
+    int available = ofinish - ostart;
+    if (len < available)
+      available = len;
+    if (b != null) {
+      System.arraycopy(obuffer, ostart, b, off, available);
+    }
+    ostart = ostart + available;
+    return available;
+  }
+
+  /**
+   * Skips <code>n</code> bytes of input from the bytes that can be read from
+   * this input stream without blocking.
+   * 
+   * <p>
+   * Fewer bytes than requested might be skipped. The actual number of bytes
+   * skipped is equal to <code>n</code> or the result of a call to
+   * {@link #available() <code>available</code>}, whichever is smaller. If
+   * <code>n</code> is less than zero, no bytes are skipped.
+   * 
+   * <p>
+   * The actual number of bytes skipped is returned.
+   * 
+   * @param n
+   *          the number of bytes to be skipped.
+   * @return the actual number of bytes skipped.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public long skip(long n) throws IOException {
+    int available = ofinish - ostart;
+    if (n > available) {
+      n = available;
+    }
+    if (n < 0) {
+      return 0;
+    }
+    ostart += n;
+    return n;
+  }
+
+  /**
+   * Returns the number of bytes that can be read from this input stream without
+   * blocking. The <code>available</code> method of <code>InputStream</code>
+   * returns <code>0</code>. This method <B>should</B> be overridden by
+   * subclasses.
+   * 
+   * @return the number of bytes that can be read from this input stream without
+   *         blocking.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public int available() throws IOException {
+    return (ofinish - ostart);
+  }
+
+  /**
+   * Closes this input stream and releases any system resources associated with
+   * the stream.
+   * <p>
+   * The <code>close</code> method of <code>SASLInputStream</code> calls the
+   * <code>close</code> method of its underlying input stream.
+   * 
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public void close() throws IOException {
+    disposeSasl();
+    ostart = 0;
+    ofinish = 0;
+    inStream.close();
+  }
+
+  /**
+   * Tests if this input stream supports the <code>mark</code> and
+   * <code>reset</code> methods, which it does not.
+   * 
+   * @return <code>false</code>, since this class does not support the
+   *         <code>mark</code> and <code>reset</code> methods.
+   */
+  public boolean markSupported() {
+    return false;
+  }
+}

+ 181 - 0
src/java/org/apache/hadoop/security/SaslOutputStream.java

@@ -0,0 +1,181 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+/**
+ * A SaslOutputStream is composed of an OutputStream and a SaslServer (or
+ * SaslClient) so that write() methods first process the data before writing
+ * them out to the underlying OutputStream. The SaslServer (or SaslClient)
+ * object must be fully initialized before being used by a SaslOutputStream.
+ */
+public class SaslOutputStream extends OutputStream {
+
+  private final DataOutputStream outStream;
+  // processed data ready to be written out
+  private byte[] saslToken;
+
+  private final SaslClient saslClient;
+  private final SaslServer saslServer;
+  // buffer holding one byte of incoming data
+  private final byte[] ibuffer = new byte[1];
+
+  /**
+   * Constructs a SASLOutputStream from an OutputStream and a SaslServer <br>
+   * Note: if the specified OutputStream or SaslServer is null, a
+   * NullPointerException may be thrown later when they are used.
+   * 
+   * @param outStream
+   *          the OutputStream to be processed
+   * @param saslServer
+   *          an initialized SaslServer object
+   */
+  public SaslOutputStream(OutputStream outStream, SaslServer saslServer) {
+    this.outStream = new DataOutputStream(outStream);
+    this.saslServer = saslServer;
+    this.saslClient = null;
+  }
+
+  /**
+   * Constructs a SASLOutputStream from an OutputStream and a SaslClient <br>
+   * Note: if the specified OutputStream or SaslClient is null, a
+   * NullPointerException may be thrown later when they are used.
+   * 
+   * @param outStream
+   *          the OutputStream to be processed
+   * @param saslClient
+   *          an initialized SaslClient object
+   */
+  public SaslOutputStream(OutputStream outStream, SaslClient saslClient) {
+    this.outStream = new DataOutputStream(outStream);
+    this.saslServer = null;
+    this.saslClient = saslClient;
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information Sasl
+   * might be using.
+   * 
+   * @exception SaslException
+   *              if a SASL error occurs.
+   */
+  private void disposeSasl() throws SaslException {
+    if (saslClient != null) {
+      saslClient.dispose();
+    }
+    if (saslServer != null) {
+      saslServer.dispose();
+    }
+  }
+
+  /**
+   * Writes the specified byte to this output stream.
+   * 
+   * @param b
+   *          the <code>byte</code>.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public void write(int b) throws IOException {
+    ibuffer[0] = (byte) b;
+    write(ibuffer, 0, 1);
+  }
+
+  /**
+   * Writes <code>b.length</code> bytes from the specified byte array to this
+   * output stream.
+   * <p>
+   * The <code>write</code> method of <code>SASLOutputStream</code> calls the
+   * <code>write</code> method of three arguments with the three arguments
+   * <code>b</code>, <code>0</code>, and <code>b.length</code>.
+   * 
+   * @param b
+   *          the data.
+   * @exception NullPointerException
+   *              if <code>b</code> is null.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public void write(byte[] b) throws IOException {
+    write(b, 0, b.length);
+  }
+
+  /**
+   * Writes <code>len</code> bytes from the specified byte array starting at
+   * offset <code>off</code> to this output stream.
+   * 
+   * @param inBuf
+   *          the data.
+   * @param off
+   *          the start offset in the data.
+   * @param len
+   *          the number of bytes to write.
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public void write(byte[] inBuf, int off, int len) throws IOException {
+    try {
+      if (saslServer != null) { // using saslServer
+        saslToken = saslServer.wrap(inBuf, off, len);
+      } else { // using saslClient
+        saslToken = saslClient.wrap(inBuf, off, len);
+      }
+    } catch (SaslException se) {
+      try {
+        disposeSasl();
+      } catch (SaslException ignored) {
+      }
+      throw se;
+    }
+    if (saslToken != null) {
+      outStream.writeInt(saslToken.length);
+      outStream.write(saslToken, 0, saslToken.length);
+      saslToken = null;
+    }
+  }
+
+  /**
+   * Flushes this output stream
+   * 
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public void flush() throws IOException {
+    outStream.flush();
+  }
+
+  /**
+   * Closes this output stream and releases any system resources associated with
+   * this stream.
+   * 
+   * @exception IOException
+   *              if an I/O error occurs.
+   */
+  public void close() throws IOException {
+    disposeSasl();
+    outStream.close();
+  }
+}

+ 249 - 0
src/java/org/apache/hadoop/security/SaslRpcClient.java

@@ -0,0 +1,249 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslClient;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.security.token.TokenIdentifier;
+
+/**
+ * A utility class that encapsulates SASL logic for RPC client
+ */
+public class SaslRpcClient {
+  public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
+
+  private final SaslClient saslClient;
+
+  /**
+   * Create a SaslRpcClient for an authentication method
+   * 
+   * @param method
+   *          the requested authentication method
+   * @param token
+   *          token to use if needed by the authentication method
+   */
+  public SaslRpcClient(AuthMethod method,
+      Token<? extends TokenIdentifier> token, String serverPrincipal)
+      throws IOException {
+    switch (method) {
+    case DIGEST:
+      if (LOG.isDebugEnabled())
+        LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
+            + " client to authenticate to service at " + token.getService());
+      saslClient = Sasl.createSaslClient(new String[] { AuthMethod.DIGEST
+          .getMechanismName() }, null, null, SaslRpcServer.SASL_DEFAULT_REALM,
+          SaslRpcServer.SASL_PROPS, new SaslClientCallbackHandler(token));
+      break;
+    case KERBEROS:
+      if (LOG.isDebugEnabled()) {
+        LOG
+            .debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
+                + " client. Server's Kerberos principal name is "
+                + serverPrincipal);
+      }
+      if (serverPrincipal == null || serverPrincipal.length() == 0) {
+        throw new IOException(
+            "Failed to specify server's Kerberos principal name");
+      }
+      String names[] = SaslRpcServer.splitKerberosName(serverPrincipal);
+      if (names.length != 3) {
+        throw new IOException(
+          "Kerberos principal name does NOT have the expected hostname part: "
+                + serverPrincipal);
+      }
+      saslClient = Sasl.createSaslClient(new String[] { AuthMethod.KERBEROS
+          .getMechanismName() }, null, names[0], names[1],
+          SaslRpcServer.SASL_PROPS, null);
+      break;
+    default:
+      throw new IOException("Unknown authentication method " + method);
+    }
+    if (saslClient == null)
+      throw new IOException("Unable to find SASL client implementation");
+  }
+
+  /**
+   * Do client side SASL authentication with server via the given InputStream
+   * and OutputStream
+   * 
+   * @param inS
+   *          InputStream to use
+   * @param outS
+   *          OutputStream to use
+   * @throws IOException
+   */
+  public void saslConnect(InputStream inS, OutputStream outS)
+      throws IOException {
+    DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
+    DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(
+        outS));
+
+    try {
+      byte[] saslToken = new byte[0];
+      if (saslClient.hasInitialResponse())
+        saslToken = saslClient.evaluateChallenge(saslToken);
+      if (saslToken != null) {
+        outStream.writeInt(saslToken.length);
+        outStream.write(saslToken, 0, saslToken.length);
+        outStream.flush();
+        if (LOG.isDebugEnabled())
+          LOG.debug("Have sent token of size " + saslToken.length
+              + " from initSASLContext.");
+      }
+      if (!saslClient.isComplete()) {
+        saslToken = new byte[inStream.readInt()];
+        if (LOG.isDebugEnabled())
+          LOG.debug("Will read input token of size " + saslToken.length
+              + " for processing by initSASLContext");
+        inStream.readFully(saslToken);
+      }
+
+      while (!saslClient.isComplete()) {
+        saslToken = saslClient.evaluateChallenge(saslToken);
+        if (saslToken != null) {
+          if (LOG.isDebugEnabled())
+            LOG.debug("Will send token of size " + saslToken.length
+                + " from initSASLContext.");
+          outStream.writeInt(saslToken.length);
+          outStream.write(saslToken, 0, saslToken.length);
+          outStream.flush();
+        }
+        if (!saslClient.isComplete()) {
+          saslToken = new byte[inStream.readInt()];
+          if (LOG.isDebugEnabled())
+            LOG.debug("Will read input token of size " + saslToken.length
+                + " for processing by initSASLContext");
+          inStream.readFully(saslToken);
+        }
+      }
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("SASL client context established. Negotiated QoP: "
+            + saslClient.getNegotiatedProperty(Sasl.QOP));
+      }
+    } catch (IOException e) {
+      saslClient.dispose();
+      throw e;
+    }
+  }
+
+  /**
+   * Get a SASL wrapped InputStream. Can be called only after saslConnect() has
+   * been called.
+   * 
+   * @param in
+   *          the InputStream to wrap
+   * @return a SASL wrapped InputStream
+   * @throws IOException
+   */
+  public InputStream getInputStream(InputStream in) throws IOException {
+    if (!saslClient.isComplete()) {
+      throw new IOException("Sasl authentication exchange hasn't completed yet");
+    }
+    return new SaslInputStream(in, saslClient);
+  }
+
+  /**
+   * Get a SASL wrapped OutputStream. Can be called only after saslConnect() has
+   * been called.
+   * 
+   * @param out
+   *          the OutputStream to wrap
+   * @return a SASL wrapped OutputStream
+   * @throws IOException
+   */
+  public OutputStream getOutputStream(OutputStream out) throws IOException {
+    if (!saslClient.isComplete()) {
+      throw new IOException("Sasl authentication exchange hasn't completed yet");
+    }
+    return new SaslOutputStream(out, saslClient);
+  }
+
+  /** Release resources used by wrapped saslClient */
+  public void dispose() throws SaslException {
+    saslClient.dispose();
+  }
+
+  private static class SaslClientCallbackHandler implements CallbackHandler {
+    private final String userName;
+    private final char[] userPassword;
+
+    public SaslClientCallbackHandler(Token<? extends TokenIdentifier> token) {
+      this.userName = SaslRpcServer.encodeIdentifier(token.getIdentifier());
+      this.userPassword = SaslRpcServer.encodePassword(token.getPassword());
+    }
+
+    public void handle(Callback[] callbacks)
+        throws UnsupportedCallbackException {
+      NameCallback nc = null;
+      PasswordCallback pc = null;
+      RealmCallback rc = null;
+      for (Callback callback : callbacks) {
+        if (callback instanceof RealmChoiceCallback) {
+          continue;
+        } else if (callback instanceof NameCallback) {
+          nc = (NameCallback) callback;
+        } else if (callback instanceof PasswordCallback) {
+          pc = (PasswordCallback) callback;
+        } else if (callback instanceof RealmCallback) {
+          rc = (RealmCallback) callback;
+        } else {
+          throw new UnsupportedCallbackException(callback,
+              "Unrecognized SASL client callback");
+        }
+      }
+      if (nc != null) {
+        if (LOG.isDebugEnabled())
+          LOG.debug("SASL client callback: setting username: " + userName);
+        nc.setName(userName);
+      }
+      if (pc != null) {
+        if (LOG.isDebugEnabled())
+          LOG.debug("SASL client callback: setting userPassword");
+        pc.setPassword(userPassword);
+      }
+      if (rc != null) {
+        if (LOG.isDebugEnabled())
+          LOG.debug("SASL client callback: setting realm: "
+              + rc.getDefaultText());
+        rc.setText(rc.getDefaultText());
+      }
+    }
+  }
+}

+ 218 - 0
src/java/org/apache/hadoop/security/SaslRpcServer.java

@@ -0,0 +1,218 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security;
+
+import java.io.ByteArrayInputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.TreeMap;
+import java.util.Map;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+
+import org.apache.commons.codec.binary.Base64;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.security.token.SecretManager;
+import org.apache.hadoop.security.token.TokenIdentifier;
+
+/**
+ * A utility class for dealing with SASL on RPC server
+ */
+public class SaslRpcServer {
+  public static final Log LOG = LogFactory.getLog(SaslRpcServer.class);
+  public static final String SASL_DEFAULT_REALM = "default";
+  public static final Map<String, String> SASL_PROPS = 
+      new TreeMap<String, String>();
+  static {
+    // Request authentication plus integrity protection
+    SASL_PROPS.put(Sasl.QOP, "auth-int");
+    // Request mutual authentication
+    SASL_PROPS.put(Sasl.SERVER_AUTH, "true");
+  }
+
+  static String encodeIdentifier(byte[] identifier) {
+    return new String(Base64.encodeBase64(identifier));
+  }
+
+  static byte[] decodeIdentifier(String identifier) {
+    return Base64.decodeBase64(identifier.getBytes());
+  }
+
+  static char[] encodePassword(byte[] password) {
+    return new String(Base64.encodeBase64(password)).toCharArray();
+  }
+
+  /** Splitting fully qualified Kerberos name into parts */
+  public static String[] splitKerberosName(String fullName) {
+    return fullName.split("[/@]");
+  }
+
+  /** Authentication method */
+  public static enum AuthMethod {
+    SIMPLE((byte) 80, ""), // no authentication
+    KERBEROS((byte) 81, "GSSAPI"), // SASL Kerberos authentication
+    DIGEST((byte) 82, "DIGEST-MD5"); // SASL DIGEST-MD5 authentication
+
+    /** The code for this method. */
+    public final byte code;
+    public final String mechanismName;
+
+    private AuthMethod(byte code, String mechanismName) {
+      this.code = code;
+      this.mechanismName = mechanismName;
+    }
+
+    private static final int FIRST_CODE = values()[0].code;
+
+    /** Return the object represented by the code. */
+    private static AuthMethod valueOf(byte code) {
+      final int i = (code & 0xff) - FIRST_CODE;
+      return i < 0 || i >= values().length ? null : values()[i];
+    }
+
+    /** Return the SASL mechanism name */
+    public String getMechanismName() {
+      return mechanismName;
+    }
+
+    /** Read from in */
+    public static AuthMethod read(DataInput in) throws IOException {
+      return valueOf(in.readByte());
+    }
+
+    /** Write to out */
+    public void write(DataOutput out) throws IOException {
+      out.write(code);
+    }
+  };
+
+  /** CallbackHandler for SASL DIGEST-MD5 mechanism */
+  public static class SaslDigestCallbackHandler implements CallbackHandler {
+    private SecretManager<TokenIdentifier> secretManager;
+
+    public SaslDigestCallbackHandler(
+        SecretManager<TokenIdentifier> secretManager) {
+      this.secretManager = secretManager;
+    }
+
+    private TokenIdentifier getIdentifier(String id) throws IOException {
+      byte[] tokenId = decodeIdentifier(id);
+      TokenIdentifier tokenIdentifier = secretManager.createIdentifier();
+      tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(
+          tokenId)));
+      return tokenIdentifier;
+    }
+
+    private char[] getPassword(TokenIdentifier tokenid) throws IOException {
+      return encodePassword(secretManager.retrievePassword(tokenid));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void handle(Callback[] callbacks) throws IOException,
+        UnsupportedCallbackException {
+      NameCallback nc = null;
+      PasswordCallback pc = null;
+      AuthorizeCallback ac = null;
+      for (Callback callback : callbacks) {
+        if (callback instanceof AuthorizeCallback) {
+          ac = (AuthorizeCallback) callback;
+        } else if (callback instanceof NameCallback) {
+          nc = (NameCallback) callback;
+        } else if (callback instanceof PasswordCallback) {
+          pc = (PasswordCallback) callback;
+        } else if (callback instanceof RealmCallback) {
+          continue; // realm is ignored
+        } else {
+          throw new UnsupportedCallbackException(callback,
+              "Unrecognized SASL DIGEST-MD5 Callback");
+        }
+      }
+      if (pc != null) {
+        TokenIdentifier tokenIdentifier = getIdentifier(nc.getDefaultName());
+        char[] password = getPassword(tokenIdentifier);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("SASL server DIGEST-MD5 callback: setting password "
+              + "for client: " + tokenIdentifier.getUsername());
+        }
+        pc.setPassword(password);
+      }
+      if (ac != null) {
+        String authid = ac.getAuthenticationID();
+        String authzid = ac.getAuthorizationID();
+        if (authid.equals(authzid)) {
+          ac.setAuthorized(true);
+        } else {
+          ac.setAuthorized(false);
+        }
+        if (ac.isAuthorized()) {
+          String username = getIdentifier(authzid).getUsername().toString();
+          if (LOG.isDebugEnabled())
+            LOG.debug("SASL server DIGEST-MD5 callback: setting "
+                + "canonicalized client ID: " + username);
+          ac.setAuthorizedID(username);
+        }
+      }
+    }
+  }
+
+  /** CallbackHandler for SASL GSSAPI Kerberos mechanism */
+  public static class SaslGssCallbackHandler implements CallbackHandler {
+
+    /** {@inheritDoc} */
+    @Override
+    public void handle(Callback[] callbacks) throws IOException,
+        UnsupportedCallbackException {
+      AuthorizeCallback ac = null;
+      for (Callback callback : callbacks) {
+        if (callback instanceof AuthorizeCallback) {
+          ac = (AuthorizeCallback) callback;
+        } else {
+          throw new UnsupportedCallbackException(callback,
+              "Unrecognized SASL GSSAPI Callback");
+        }
+      }
+      if (ac != null) {
+        String authid = ac.getAuthenticationID();
+        String authzid = ac.getAuthorizationID();
+        if (authid.equals(authzid)) {
+          ac.setAuthorized(true);
+        } else {
+          ac.setAuthorized(false);
+        }
+        if (ac.isAuthorized()) {
+          if (LOG.isDebugEnabled())
+            LOG.debug("SASL server GSSAPI callback: setting "
+                + "canonicalized client ID: " + authzid);
+          ac.setAuthorizedID(authzid);
+        }
+      }
+    }
+  }
+}

+ 6 - 0
src/java/org/apache/hadoop/security/token/SecretManager.java

@@ -61,6 +61,12 @@ public abstract class SecretManager<T extends TokenIdentifier> {
    */
    */
   public abstract byte[] retrievePassword(T identifier) throws InvalidToken;
   public abstract byte[] retrievePassword(T identifier) throws InvalidToken;
   
   
+  /**
+   * Create an empty token identifier.
+   * @return the newly created empty token identifier
+   */
+  public abstract T createIdentifier();
+  
   /**
   /**
    * The name of the hashing algorithm.
    * The name of the hashing algorithm.
    */
    */

+ 6 - 0
src/java/org/apache/hadoop/security/token/TokenIdentifier.java

@@ -35,6 +35,12 @@ public abstract class TokenIdentifier implements Writable {
    * @return the kind of the token
    * @return the kind of the token
    */
    */
   public abstract Text getKind();
   public abstract Text getKind();
+  
+  /**
+   * Get the username encoded in the token identifier
+   * @return the username
+   */
+  public abstract Text getUsername();
 
 
   /**
   /**
    * Get the bytes for the token identifier
    * Get the bytes for the token identifier

+ 31 - 0
src/java/org/apache/hadoop/security/token/TokenInfo.java

@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security.token;
+
+import java.lang.annotation.*;
+
+/**
+ * Indicates Token related information to be used
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.TYPE)
+public @interface TokenInfo {
+  /** The type of TokenSelector to be used */
+  Class<? extends TokenSelector<? extends TokenIdentifier>> value();
+}

+ 34 - 0
src/java/org/apache/hadoop/security/token/TokenSelector.java

@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.security.token;
+
+import java.util.Collection;
+
+import org.apache.hadoop.io.Text;
+
+/**
+ * Select token of type T from tokens for use with named service
+ * 
+ * @param <T>
+ *          T extends TokenIdentifier
+ */
+public interface TokenSelector<T extends TokenIdentifier> {
+  Token<T> selectToken(Text service,
+      Collection<Token<? extends TokenIdentifier>> tokens);
+}

+ 6 - 0
src/test/core-site.xml

@@ -63,4 +63,10 @@
   <description>The name of the s3n file system for testing.</description>
   <description>The name of the s3n file system for testing.</description>
 </property>
 </property>
 
 
+<!-- Turn security off for tests by default -->
+<property>
+  <name>hadoop.security.authentication</name>
+  <value>simple</value>
+</property>
+
 </configuration>
 </configuration>

+ 3 - 3
src/test/core/org/apache/hadoop/ipc/TestRPC.java

@@ -68,7 +68,7 @@ public class TestRPC extends TestCase {
     int[] exchange(int[] values) throws IOException;
     int[] exchange(int[] values) throws IOException;
   }
   }
 
 
-  public class TestImpl implements TestProtocol {
+  public static class TestImpl implements TestProtocol {
     int fastPingCounter = 0;
     int fastPingCounter = 0;
     
     
     public long getProtocolVersion(String protocol, long clientVersion) {
     public long getProtocolVersion(String protocol, long clientVersion) {
@@ -189,7 +189,7 @@ public class TestRPC extends TestCase {
     System.out.println("Testing Slow RPC");
     System.out.println("Testing Slow RPC");
     // create a server with two handlers
     // create a server with two handlers
     Server server = RPC.getServer(TestProtocol.class,
     Server server = RPC.getServer(TestProtocol.class,
-                                  new TestImpl(), ADDRESS, 0, 2, false, conf);
+                                  new TestImpl(), ADDRESS, 0, 2, false, conf, null);
     TestProtocol proxy = null;
     TestProtocol proxy = null;
     
     
     try {
     try {
@@ -339,7 +339,7 @@ public class TestRPC extends TestCase {
     ServiceAuthorizationManager.refresh(conf, new TestPolicyProvider());
     ServiceAuthorizationManager.refresh(conf, new TestPolicyProvider());
     
     
     Server server = RPC.getServer(TestProtocol.class,
     Server server = RPC.getServer(TestProtocol.class,
-                                  new TestImpl(), ADDRESS, 0, 5, true, conf);
+                                  new TestImpl(), ADDRESS, 0, 5, true, conf, null);
 
 
     TestProtocol proxy = null;
     TestProtocol proxy = null;
 
 

+ 216 - 0
src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java

@@ -0,0 +1,216 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.ipc;
+
+import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.Collection;
+
+import org.apache.commons.logging.*;
+import org.apache.commons.logging.impl.Log4JLogger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.KerberosInfo;
+import org.apache.hadoop.security.token.SecretManager;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.security.token.TokenIdentifier;
+import org.apache.hadoop.security.token.TokenInfo;
+import org.apache.hadoop.security.token.TokenSelector;
+import org.apache.hadoop.security.SaslInputStream;
+import org.apache.hadoop.security.SaslRpcClient;
+import org.apache.hadoop.security.SaslRpcServer;
+import org.apache.hadoop.security.UserGroupInformation;
+
+import org.apache.log4j.Level;
+import org.junit.Test;
+
+/** Unit tests for using Sasl over RPC. */
+public class TestSaslRPC {
+  private static final String ADDRESS = "0.0.0.0";
+
+  public static final Log LOG =
+    LogFactory.getLog(TestSaslRPC.class);
+  
+  static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
+  private static Configuration conf;
+  static {
+    conf = new Configuration();
+    conf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos");
+    UserGroupInformation.setConfiguration(conf);
+  }
+
+  static {
+    ((Log4JLogger) Client.LOG).getLogger().setLevel(Level.ALL);
+    ((Log4JLogger) Server.LOG).getLogger().setLevel(Level.ALL);
+    ((Log4JLogger) SaslRpcClient.LOG).getLogger().setLevel(Level.ALL);
+    ((Log4JLogger) SaslRpcServer.LOG).getLogger().setLevel(Level.ALL);
+    ((Log4JLogger) SaslInputStream.LOG).getLogger().setLevel(Level.ALL);
+  }
+
+  public static class TestTokenIdentifier extends TokenIdentifier {
+    private Text tokenid;
+    final static Text KIND_NAME = new Text("test.token");
+    
+    public TestTokenIdentifier() {
+      this.tokenid = new Text();
+    }
+    public TestTokenIdentifier(Text tokenid) {
+      this.tokenid = tokenid;
+    }
+    @Override
+    public Text getKind() {
+      return KIND_NAME;
+    }
+    @Override
+    public Text getUsername() {
+      return tokenid;
+    }
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      tokenid.readFields(in);
+    }
+    @Override
+    public void write(DataOutput out) throws IOException {
+      tokenid.write(out);
+    }
+  }
+  
+  public static class TestTokenSecretManager extends
+      SecretManager<TestTokenIdentifier> {
+    public byte[] createPassword(TestTokenIdentifier id) {
+      return id.getBytes();
+    }
+
+    public byte[] retrievePassword(TestTokenIdentifier id) 
+        throws InvalidToken {
+      return id.getBytes();
+    }
+    
+    public TestTokenIdentifier createIdentifier() {
+      return new TestTokenIdentifier();
+    }
+  }
+
+  public static class TestTokenSelector implements
+      TokenSelector<TestTokenIdentifier> {
+    @SuppressWarnings("unchecked")
+    @Override
+    public Token<TestTokenIdentifier> selectToken(Text service,
+        Collection<Token<? extends TokenIdentifier>> tokens) {
+      if (service == null) {
+        return null;
+      }
+      for (Token<? extends TokenIdentifier> token : tokens) {
+        if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
+            && service.equals(token.getService())) {
+          return (Token<TestTokenIdentifier>) token;
+        }
+      }
+      return null;
+    }
+  }
+  
+  @KerberosInfo(SERVER_PRINCIPAL_KEY)
+  @TokenInfo(TestTokenSelector.class)
+  public interface TestSaslProtocol extends TestRPC.TestProtocol {
+  }
+  
+  public static class TestSaslImpl extends TestRPC.TestImpl implements
+      TestSaslProtocol {
+  }
+
+  @Test
+  public void testDigestRpc() throws Exception {
+    TestTokenSecretManager sm = new TestTokenSecretManager();
+    final Server server = RPC.getServer(TestSaslProtocol.class,
+        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
+
+    server.start();
+
+    final UserGroupInformation current = UserGroupInformation.getCurrentUser();
+    final InetSocketAddress addr = NetUtils.getConnectAddress(server);
+    TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
+        .getUserName()));
+    Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
+        sm);
+    Text host = new Text(addr.getAddress().getHostAddress() + ":"
+        + addr.getPort());
+    token.setService(host);
+    LOG.info("Service IP address for token is " + host);
+    current.addToken(token);
+
+    TestSaslProtocol proxy = null;
+    try {
+      proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, conf);
+      proxy.ping();
+    } finally {
+      server.stop();
+      if (proxy != null) {
+        RPC.stopProxy(proxy);
+      }
+    }
+  }
+  
+  static void testKerberosRpc(String principal, String keytab) throws Exception {
+    final Configuration newConf = new Configuration(conf);
+    newConf.set(SERVER_PRINCIPAL_KEY, principal);
+    UserGroupInformation.loginUserFromKeytab(principal, keytab);
+    UserGroupInformation current = UserGroupInformation.getCurrentUser();
+    System.out.println("UGI: " + current);
+
+    Server server = RPC.getServer(TestSaslProtocol.class, new TestSaslImpl(),
+        ADDRESS, 0, 5, true, newConf, null);
+    TestSaslProtocol proxy = null;
+
+    server.start();
+
+    InetSocketAddress addr = NetUtils.getConnectAddress(server);
+    try {
+      proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      proxy.ping();
+    } finally {
+      server.stop();
+      if (proxy != null) {
+        RPC.stopProxy(proxy);
+      }
+    }
+  }
+  
+  public static void main(String[] args) throws Exception {
+    System.out.println("Testing Kerberos authentication over RPC");
+    if (args.length != 2) {
+      System.err
+          .println("Usage: java <options> org.apache.hadoop.ipc.TestSaslRPC "
+              + " <serverPrincipal> <keytabFile>");
+      System.exit(-1);
+    }
+    String principal = args[0];
+    String keytab = args[1];
+    testKerberosRpc(principal, keytab);
+  }
+
+}