浏览代码

HADOOP-2184. RPC Support for user permissions and authentication.
(Raghu Angadi via dhruba)



git-svn-id: https://svn.apache.org/repos/asf/lucene/hadoop/trunk@601221 13f79535-47bb-0310-9956-ffa450edef68

Dhruba Borthakur 17 年之前
父节点
当前提交
6fa12d241f

+ 3 - 0
CHANGES.txt

@@ -22,6 +22,9 @@ Trunk (unreleased changes)
 
     HADOOP-2288.  Enhance FileSystem API to support access control.
     (Tsz Wo (Nicholas), SZE via dhruba)
+
+    HADOOP-2184.  RPC Support for user permissions and authentication.
+    (Raghu Angadi via dhruba)
     
   NEW FEATURES
 

+ 11 - 1
src/java/org/apache/hadoop/dfs/DFSClient.java

@@ -26,6 +26,8 @@ import org.apache.hadoop.ipc.*;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.conf.*;
 import org.apache.hadoop.dfs.DistributedFileSystem.DiskStatus;
+import org.apache.hadoop.security.UnixUserGroupInformation;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.util.*;
 
 import org.apache.commons.logging.*;
@@ -38,6 +40,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.ConcurrentHashMap;
 
 import javax.net.SocketFactory;
+import javax.security.auth.login.LoginException;
 
 /********************************************************
  * DFSClient can connect to a Hadoop Filesystem and 
@@ -112,9 +115,16 @@ class DFSClient implements FSConstants {
     methodNameToPolicyMap.put("getEditLogSize", methodPolicy);
     methodNameToPolicyMap.put("create", methodPolicy);
 
+    UserGroupInformation userInfo;
+    try {
+      userInfo = UnixUserGroupInformation.login(conf);
+    } catch (LoginException e) {
+      throw new IOException(e.getMessage());
+    }
+
     return (ClientProtocol) RetryProxy.create(ClientProtocol.class,
         RPC.getProxy(ClientProtocol.class,
-            ClientProtocol.versionID, nameNodeAddr, conf,
+            ClientProtocol.versionID, nameNodeAddr, userInfo, conf,
             NetUtils.getSocketFactory(conf, ClientProtocol.class)),
         methodNameToPolicyMap);
   }

+ 90 - 29
src/java/org/apache/hadoop/ipc/Client.java

@@ -41,10 +41,12 @@ import org.apache.commons.logging.*;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.dfs.FSConstants;
+import org.apache.hadoop.io.ObjectWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.StringUtils;
 
@@ -55,14 +57,11 @@ import org.apache.hadoop.util.StringUtils;
  * @see Server
  */
 public class Client {
-  /** Should the client send the header on the connection? */
-  private static final boolean SEND_HEADER = true;
-  private static final byte CURRENT_VERSION = 0;
   
   public static final Log LOG =
     LogFactory.getLog("org.apache.hadoop.ipc.Client");
-  private Hashtable<InetSocketAddress, Connection> connections =
-    new Hashtable<InetSocketAddress, Connection>();
+  private Hashtable<ConnectionId, Connection> connections =
+    new Hashtable<ConnectionId, Connection>();
 
   private Class valueClass;                       // class of call values
   private int timeout;// timeout for calls
@@ -119,7 +118,7 @@ public class Client {
    * socket connected to a remote address.  Calls are multiplexed through this
    * socket: responses may be delivered out of order. */
   private class Connection extends Thread {
-    private InetSocketAddress address;            // address of server
+    private ConnectionId remoteId;
     private Socket socket = null;                 // connected socket
     private DataInputStream in;                   
     private DataOutputStream out;
@@ -132,11 +131,17 @@ public class Client {
     private boolean shouldCloseConnection = false;
 
     public Connection(InetSocketAddress address) throws IOException {
-      if (address.isUnresolved()) {
-        throw new UnknownHostException("unknown host: " + address.getHostName());
+      this(new ConnectionId(address, null));
+    }
+    
+    public Connection(ConnectionId remoteId) throws IOException {
+      if (remoteId.getAddress().isUnresolved()) {
+        throw new UnknownHostException("unknown host: " + 
+                                       remoteId.getAddress().getHostName());
       }
-      this.address = address;
-      this.setName("IPC Client connection to " + address.toString());
+      this.remoteId = remoteId;
+      this.setName("IPC Client connection to " + 
+                   remoteId.getAddress().toString());
       this.setDaemon(true);
     }
 
@@ -149,7 +154,7 @@ public class Client {
       while (true) {
         try {
           this.socket = socketFactory.createSocket();
-          this.socket.connect(address, FSConstants.READ_TIMEOUT);
+          this.socket.connect(remoteId.getAddress(), FSConstants.READ_TIMEOUT);
           break;
         } catch (IOException ie) { //SocketTimeoutException is also caught 
           if (failures == maxRetries) {
@@ -165,7 +170,7 @@ public class Client {
             throw ie;
           }
           failures++;
-          LOG.info("Retrying connect to server: " + address + 
+          LOG.info("Retrying connect to server: " + remoteId.getAddress() + 
                    ". Already tried " + failures + " time(s).");
           try { 
             Thread.sleep(1000);
@@ -195,13 +200,22 @@ public class Client {
                }
              }
            }));
-      if (SEND_HEADER) {
-        out.write(Server.HEADER.array());
-        out.write(CURRENT_VERSION);
-      }
+      writeHeader();
       notify();
     }
 
+    private synchronized void writeHeader() throws IOException {
+      out.write(Server.HEADER.array());
+      out.write(Server.CURRENT_VERSION);
+      //When there are more fields we can have ConnectionHeader Writable.
+      DataOutputBuffer buf = new DataOutputBuffer();
+      ObjectWritable.writeObject(buf, remoteId.getTicket(), 
+                                 UserGroupInformation.class, conf);
+      int bufLen = buf.getLength();
+      out.writeInt(bufLen);
+      out.write(buf.getData(), 0, bufLen);
+    }
+    
     private synchronized boolean waitForWork() {
       //wait till someone signals us to start reading RPC response or
       //close the connection. If we are idle long enough (blocked in wait),
@@ -238,7 +252,7 @@ public class Client {
     }
 
     public InetSocketAddress getRemoteAddress() {
-      return address;
+      return remoteId.getAddress();
     }
 
     public void setCloseConnection() {
@@ -294,8 +308,8 @@ public class Client {
         //We don't want to remove this again as some other thread might have
         //actually put a new Connection object in the table in the meantime.
         synchronized (connections) {
-          if (connections.get(address) == this) {
-            connections.remove(address);
+          if (connections.get(remoteId) == this) {
+            connections.remove(remoteId);
           }
         }
         close();
@@ -333,8 +347,8 @@ public class Client {
       } finally {
         if (error) {
           synchronized (connections) {
-            if (connections.get(address) == this)
-              connections.remove(address);
+            if (connections.get(remoteId) == this)
+              connections.remove(remoteId);
           }
           close();                                // close on error
         }
@@ -467,8 +481,14 @@ public class Client {
    * <code>address</code>, returning the value.  Throws exceptions if there are
    * network problems or if the remote code threw an exception. */
   public Writable call(Writable param, InetSocketAddress address)
-    throws InterruptedException, IOException {
-    Connection connection = getConnection(address);
+  throws InterruptedException, IOException {
+      return call(param, address, null);
+  }
+  
+  public Writable call(Writable param, InetSocketAddress addr, 
+                       UserGroupInformation ticket)  
+                       throws InterruptedException, IOException {
+    Connection connection = getConnection(addr, ticket);
     Call call = new Call(param);
     synchronized (call) {
       connection.sendParam(call);                 // send the parameter
@@ -501,7 +521,7 @@ public class Client {
       for (int i = 0; i < params.length; i++) {
         ParallelCall call = new ParallelCall(params[i], results, i);
         try {
-          Connection connection = getConnection(addresses[i]);
+          Connection connection = getConnection(addresses[i], null);
           connection.sendParam(call);             // send each parameter
         } catch (IOException e) {
           LOG.info("Calling "+addresses[i]+" caught: " + 
@@ -523,14 +543,20 @@ public class Client {
 
   /** Get a connection from the pool, or create a new one and add it to the
    * pool.  Connections to a given host/port are reused. */
-  private Connection getConnection(InetSocketAddress address)
-    throws IOException {
+  private Connection getConnection(InetSocketAddress addr, 
+                                   UserGroupInformation ticket)
+                                   throws IOException {
     Connection connection;
+    /* We could avoid this allocation for each RPC by having a
+     * reusable ConnectionId object with a set() method. We would need to
+     * manage the refs for keys in the Hashtable properly. For now it's ok.
+     */
+    ConnectionId remoteId = new ConnectionId(addr, ticket);
     synchronized (connections) {
-      connection = connections.get(address);
+      connection = connections.get(remoteId);
       if (connection == null) {
-        connection = new Connection(address);
-        connections.put(address, connection);
+        connection = new Connection(remoteId);
+        connections.put(remoteId, connection);
         connection.start();
       }
       connection.incrementRef();
@@ -543,4 +569,39 @@ public class Client {
     return connection;
   }
 
+  /**
+   * This class holds the address and the user ticket. The client connections
+   * to servers are uniquely identified by &lt;remoteAddress, ticket&gt;.
+   */
+  private static class ConnectionId {
+    InetSocketAddress address;
+    UserGroupInformation ticket;
+    
+    ConnectionId(InetSocketAddress address, UserGroupInformation ticket) {
+      this.address = address;
+      this.ticket = ticket;
+    }
+    
+    InetSocketAddress getAddress() {
+      return address;
+    }
+    UserGroupInformation getTicket() {
+      return ticket;
+    }
+    
+    @Override
+    public boolean equals(Object obj) {
+     if (obj instanceof ConnectionId) {
+       ConnectionId id = (ConnectionId) obj;
+       return address.equals(id.address) && ticket == id.ticket;
+       //Note: ticket is compared by reference.
+     }
+     return false;
+    }
+    
+    @Override
+    public int hashCode() {
+      return address.hashCode() ^ System.identityHashCode(ticket);
+    }
+  }  
 }

+ 15 - 5
src/java/org/apache/hadoop/ipc/RPC.java

@@ -37,6 +37,7 @@ import org.apache.commons.logging.*;
 
 import org.apache.hadoop.io.*;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.conf.*;
 
 /** A simple RPC mechanism.
@@ -169,12 +170,13 @@ public class RPC {
 
   private static class Invoker implements InvocationHandler {
     private InetSocketAddress address;
+    private UserGroupInformation ticket;
     private Client client;
 
-    public Invoker(InetSocketAddress address, Configuration conf,
-        SocketFactory factory) {
-
+    public Invoker(InetSocketAddress address, UserGroupInformation ticket, 
+                   Configuration conf, SocketFactory factory) {
       this.address = address;
+      this.ticket = ticket;
       this.client = getClient(conf, factory);
     }
 
@@ -182,7 +184,7 @@ public class RPC {
       throws Throwable {
       long startTime = System.currentTimeMillis();
       ObjectWritable value = (ObjectWritable)
-        client.call(new Invocation(method, args), address);
+        client.call(new Invocation(method, args), address, ticket);
       long callTime = System.currentTimeMillis() - startTime;
       LOG.debug("Call: " + method.getName() + " " + callTime);
       return value.get();
@@ -261,11 +263,19 @@ public class RPC {
   public static VersionedProtocol getProxy(Class<?> protocol,
       long clientVersion, InetSocketAddress addr, Configuration conf,
       SocketFactory factory) throws IOException {
+    return getProxy(protocol, clientVersion, addr, null, conf, factory);
+  }
+  
+  /** Construct a client-side proxy object that implements the named protocol,
+   * talking to a server at the named address using the given user ticket. */
+  public static VersionedProtocol getProxy(Class<?> protocol,
+      long clientVersion, InetSocketAddress addr, UserGroupInformation ticket,
+      Configuration conf, SocketFactory factory) throws IOException {    
 
     VersionedProtocol proxy =
         (VersionedProtocol) Proxy.newProxyInstance(
             protocol.getClassLoader(), new Class[] { protocol },
-            new Invoker(addr, conf, factory));
+            new Invoker(addr, ticket, conf, factory));
     long serverVersion = proxy.getProtocolVersion(protocol.getName(), 
                                                   clientVersion);
     if (serverVersion == clientVersion) {

+ 76 - 28
src/java/org/apache/hadoop/ipc/Server.java

@@ -35,7 +35,6 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
-import java.net.SocketAddress;
 import java.net.SocketException;
 import java.net.UnknownHostException;
 
@@ -48,9 +47,11 @@ import java.util.Random;
 import org.apache.commons.logging.*;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.ObjectWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.ipc.SocketChannelOutputStream;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.util.*;
 
 /** An abstract IPC service.  IPC calls take a single {@link Writable} as a
@@ -66,6 +67,9 @@ public abstract class Server {
    */
   public static final ByteBuffer HEADER = ByteBuffer.wrap("hrpc".getBytes());
   
+  // 1 : Ticket is added to connection header
+  public static final byte CURRENT_VERSION = 1;
+  
   /**
    * How much time should be allocated for actually running the handler?
    * Calls that are older than ipc.timeout * MAX_CALL_QUEUE_TIME
@@ -113,6 +117,14 @@ public abstract class Server {
     InetAddress addr = getRemoteIp();
     return (addr == null) ? null : addr.getHostAddress();
   }
+
+  /** Returns the {@link UserGroupInformation} associated with the current RPC.
+   *  Returns null if user information is not available.
+   */
+  public static UserGroupInformation getUserInfo() {
+    Call call = CurCall.get();
+    return (call == null) ? null : call.connection.ticket;
+  }
   
   private String bindAddress; 
   private int port;                               // port we listen on
@@ -418,7 +430,10 @@ public abstract class Server {
 
   /** Reads calls from a connection and queues them for handling. */
   private class Connection {
-    private boolean firstData = true;
+    private boolean versionRead = false; //if initial signature and
+                                         //version are read
+    private boolean headerRead = false;  //if the connection header that
+                                         //follows version is read.
     private SocketChannel channel;
     private SelectionKey key;
     private ByteBuffer data;
@@ -432,6 +447,7 @@ public abstract class Server {
     // disconnected, we can say where it used to connect to.
     private String hostAddress;
     private int remotePort;
+    private UserGroupInformation ticket = null;
 
     public Connection(SelectionKey key, SocketChannel channel, 
                       long lastContact) {
@@ -476,42 +492,74 @@ public abstract class Server {
     }
 
     public int readAndProcess() throws IOException, InterruptedException {
-      int count = -1;
-      if (dataLengthBuffer.remaining() > 0) {
-        count = channel.read(dataLengthBuffer);       
-        if (count < 0 || dataLengthBuffer.remaining() > 0) 
-          return count;        
-        dataLengthBuffer.flip(); 
-        // Is this a new style header?
-        if (firstData && HEADER.equals(dataLengthBuffer)) {
-          // If so, read the version
+      while (true) {
+        /* Read at most one RPC. If the header is not read completely yet
+         * then iterate until we read the first RPC or until there is no data left.
+         */    
+        int count = -1;
+        if (dataLengthBuffer.remaining() > 0) {
+          count = channel.read(dataLengthBuffer);       
+          if (count < 0 || dataLengthBuffer.remaining() > 0) 
+            return count;
+        }
+      
+        if (!versionRead) {
+          //Every connection is expected to send the header.
           ByteBuffer versionBuffer = ByteBuffer.allocate(1);
           count = channel.read(versionBuffer);
-          if (count < 0) {
+          if (count <= 0) {
             return count;
           }
-          // read the first length
-          dataLengthBuffer.clear();
-          count = channel.read(dataLengthBuffer);
-          if (count < 0 || dataLengthBuffer.remaining() > 0) {
-            return count;
+          int version = versionBuffer.get(0);
+          
+          dataLengthBuffer.flip();          
+          if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) {
+            //Warning is ok since this is not supposed to happen.
+            LOG.warn("Incorrect header or version mismatch from " + 
+                     hostAddress + ":" + remotePort);
+            return -1;
           }
+          dataLengthBuffer.clear();
+          versionRead = true;
+          continue;
+        }
+        
+        if (data == null) {
           dataLengthBuffer.flip();
-          firstData = false;
+          dataLength = dataLengthBuffer.getInt();
+          data = ByteBuffer.allocate(dataLength);
         }
-        dataLength = dataLengthBuffer.getInt();
-        data = ByteBuffer.allocate(dataLength);
-      }
-      count = channel.read(data);
-      if (data.remaining() == 0) {
-        data.flip();
-        processData();
-        dataLengthBuffer.flip();
-        data = null; 
+        
+        count = channel.read(data);
+        
+        if (data.remaining() == 0) {
+          dataLengthBuffer.clear();
+          data.flip();
+          if (headerRead) {
+            processData();
+            data = null;
+            return count;
+          } else {
+            processHeader();
+            headerRead = true;
+            data = null;
+            continue;
+          }
+        } 
+        return count;
       }
-      return count;
     }
 
+    /// Reads the header following version
+    private void processHeader() throws IOException {
+      /* In the current version, it is just a ticket.
+       * Later we could introduce a "ConnectionHeader" class.
+       */
+      DataInputStream in =
+        new DataInputStream(new ByteArrayInputStream(data.array()));
+      ticket = (UserGroupInformation) ObjectWritable.readObject(in, conf);
+    }
+    
     private void processData() throws  IOException, InterruptedException {
       DataInputStream dis =
         new DataInputStream(new ByteArrayInputStream(data.array()));