|
@@ -32,6 +32,7 @@ import java.net.SocketException;
|
|
|
import java.net.UnknownHostException;
|
|
|
import java.nio.ByteBuffer;
|
|
|
import java.nio.channels.CancelledKeyException;
|
|
|
+import java.nio.channels.Channels;
|
|
|
import java.nio.channels.ClosedChannelException;
|
|
|
import java.nio.channels.ReadableByteChannel;
|
|
|
import java.nio.channels.SelectionKey;
|
|
@@ -39,7 +40,6 @@ import java.nio.channels.Selector;
|
|
|
import java.nio.channels.ServerSocketChannel;
|
|
|
import java.nio.channels.SocketChannel;
|
|
|
import java.nio.channels.WritableByteChannel;
|
|
|
-import java.security.PrivilegedActionException;
|
|
|
import java.security.PrivilegedExceptionAction;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
@@ -52,15 +52,26 @@ import java.util.concurrent.BlockingQueue;
|
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
|
import java.util.concurrent.LinkedBlockingQueue;
|
|
|
|
|
|
+import javax.security.sasl.Sasl;
|
|
|
+import javax.security.sasl.SaslException;
|
|
|
+import javax.security.sasl.SaslServer;
|
|
|
+
|
|
|
import org.apache.commons.logging.Log;
|
|
|
import org.apache.commons.logging.LogFactory;
|
|
|
import org.apache.hadoop.conf.Configuration;
|
|
|
import org.apache.hadoop.io.Writable;
|
|
|
import org.apache.hadoop.io.WritableUtils;
|
|
|
import org.apache.hadoop.ipc.metrics.RpcMetrics;
|
|
|
+import org.apache.hadoop.security.AccessControlException;
|
|
|
+import org.apache.hadoop.security.SaslRpcServer;
|
|
|
+import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
|
|
|
+import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler;
|
|
|
+import org.apache.hadoop.security.SaslRpcServer.SaslGssCallbackHandler;
|
|
|
import org.apache.hadoop.security.UserGroupInformation;
|
|
|
import org.apache.hadoop.security.authorize.AuthorizationException;
|
|
|
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
|
|
|
+import org.apache.hadoop.security.token.TokenIdentifier;
|
|
|
+import org.apache.hadoop.security.token.SecretManager;
|
|
|
import org.apache.hadoop.util.ReflectionUtils;
|
|
|
import org.apache.hadoop.util.StringUtils;
|
|
|
|
|
@@ -80,7 +91,8 @@ public abstract class Server {
|
|
|
|
|
|
// 1 : Introduce ping and server does not throw away RPCs
|
|
|
// 3 : Introduce the protocol into the RPC connection header
|
|
|
- public static final byte CURRENT_VERSION = 3;
|
|
|
+ // 4 : Introduced SASL security layer
|
|
|
+ public static final byte CURRENT_VERSION = 4;
|
|
|
|
|
|
/**
|
|
|
* How many calls/handler are allowed in the queue.
|
|
@@ -152,6 +164,7 @@ public abstract class Server {
|
|
|
protected RpcMetrics rpcMetrics;
|
|
|
|
|
|
private Configuration conf;
|
|
|
+ private SecretManager<TokenIdentifier> secretManager;
|
|
|
|
|
|
private int maxQueueSize;
|
|
|
private int socketSendBufferSize;
|
|
@@ -424,7 +437,7 @@ public abstract class Server {
|
|
|
if (count < 0) {
|
|
|
if (LOG.isDebugEnabled())
|
|
|
LOG.debug(getName() + ": disconnecting client " +
|
|
|
- c.getHostAddress() + ". Number of active connections: "+
|
|
|
+ c + ". Number of active connections: "+
|
|
|
numConnections);
|
|
|
closeConnection(c);
|
|
|
c = null;
|
|
@@ -694,8 +707,7 @@ public abstract class Server {
|
|
|
|
|
|
/** Reads calls from a connection and queues them for handling. */
|
|
|
private class Connection {
|
|
|
- private boolean versionRead = false; //if initial signature and
|
|
|
- //version are read
|
|
|
+ private boolean rpcHeaderRead = false; // if initial rpc header is read
|
|
|
private boolean headerRead = false; //if the connection header that
|
|
|
//follows version is read.
|
|
|
|
|
@@ -714,6 +726,13 @@ public abstract class Server {
|
|
|
|
|
|
ConnectionHeader header = new ConnectionHeader();
|
|
|
Class<?> protocol;
|
|
|
+ boolean useSasl;
|
|
|
+ SaslServer saslServer;
|
|
|
+ private AuthMethod authMethod;
|
|
|
+ private boolean saslContextEstablished;
|
|
|
+ private ByteBuffer rpcHeaderBuffer;
|
|
|
+ private ByteBuffer unwrappedData;
|
|
|
+ private ByteBuffer unwrappedDataLengthBuffer;
|
|
|
|
|
|
UserGroupInformation user = null;
|
|
|
|
|
@@ -722,6 +741,10 @@ public abstract class Server {
|
|
|
private final Call authFailedCall =
|
|
|
new Call(AUTHROIZATION_FAILED_CALLID, null, null);
|
|
|
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
|
|
|
+ // Fake 'call' for SASL context setup
|
|
|
+ private static final int SASL_CALLID = -33;
|
|
|
+ private final Call saslCall = new Call(SASL_CALLID, null, null);
|
|
|
+ private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream();
|
|
|
|
|
|
public Connection(SelectionKey key, SocketChannel channel,
|
|
|
long lastContact) {
|
|
@@ -729,6 +752,8 @@ public abstract class Server {
|
|
|
this.lastContact = lastContact;
|
|
|
this.data = null;
|
|
|
this.dataLengthBuffer = ByteBuffer.allocate(4);
|
|
|
+ this.unwrappedData = null;
|
|
|
+ this.unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
|
|
|
this.socket = channel.socket();
|
|
|
InetAddress addr = socket.getInetAddress();
|
|
|
if (addr == null) {
|
|
@@ -786,6 +811,92 @@ public abstract class Server {
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
+ private void saslReadAndProcess(byte[] saslToken) throws IOException,
|
|
|
+ InterruptedException {
|
|
|
+ if (!saslContextEstablished) {
|
|
|
+ if (saslServer == null) {
|
|
|
+ switch (authMethod) {
|
|
|
+ case DIGEST:
|
|
|
+ saslServer = Sasl.createSaslServer(AuthMethod.DIGEST
|
|
|
+ .getMechanismName(), null, SaslRpcServer.SASL_DEFAULT_REALM,
|
|
|
+ SaslRpcServer.SASL_PROPS, new SaslDigestCallbackHandler(
|
|
|
+ secretManager));
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ UserGroupInformation current = UserGroupInformation
|
|
|
+ .getCurrentUser();
|
|
|
+ String fullName = current.getUserName();
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Kerberos principal name is " + fullName);
|
|
|
+ final String names[] = SaslRpcServer.splitKerberosName(fullName);
|
|
|
+ if (names.length != 3) {
|
|
|
+ throw new IOException(
|
|
|
+ "Kerberos principal name does NOT have the expected "
|
|
|
+ + "hostname part: " + fullName);
|
|
|
+ }
|
|
|
+ current.doAs(new PrivilegedExceptionAction<Object>() {
|
|
|
+ @Override
|
|
|
+ public Object run() throws IOException {
|
|
|
+ saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS
|
|
|
+ .getMechanismName(), names[0], names[1],
|
|
|
+ SaslRpcServer.SASL_PROPS, new SaslGssCallbackHandler());
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ if (saslServer == null)
|
|
|
+ throw new IOException(
|
|
|
+ "Unable to find SASL server implementation for "
|
|
|
+ + authMethod.getMechanismName());
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Created SASL server with mechanism = "
|
|
|
+ + authMethod.getMechanismName());
|
|
|
+ }
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Have read input token of size " + saslToken.length
|
|
|
+ + " for processing by saslServer.evaluateResponse()");
|
|
|
+ byte[] replyToken = saslServer.evaluateResponse(saslToken);
|
|
|
+ if (replyToken != null) {
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Will send token of size " + replyToken.length
|
|
|
+ + " from saslServer.");
|
|
|
+ saslCall.connection = this;
|
|
|
+ saslResponse.reset();
|
|
|
+ DataOutputStream out = new DataOutputStream(saslResponse);
|
|
|
+ out.writeInt(replyToken.length);
|
|
|
+ out.write(replyToken, 0, replyToken.length);
|
|
|
+ saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray()));
|
|
|
+ responder.doRespond(saslCall);
|
|
|
+ }
|
|
|
+ if (saslServer.isComplete()) {
|
|
|
+ if (LOG.isDebugEnabled()) {
|
|
|
+ LOG.debug("SASL server context established. Negotiated QoP is "
|
|
|
+ + saslServer.getNegotiatedProperty(Sasl.QOP));
|
|
|
+ }
|
|
|
+ user = UserGroupInformation.createRemoteUser(saslServer
|
|
|
+ .getAuthorizationID());
|
|
|
+ LOG.info("SASL server successfully authenticated client: " + user);
|
|
|
+ saslContextEstablished = true;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Have read input token of size " + saslToken.length
|
|
|
+ + " for processing by saslServer.unwrap()");
|
|
|
+ byte[] plaintextData = saslServer
|
|
|
+ .unwrap(saslToken, 0, saslToken.length);
|
|
|
+ processUnwrappedData(plaintextData);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void disposeSasl() {
|
|
|
+ if (saslServer != null) {
|
|
|
+ try {
|
|
|
+ saslServer.dispose();
|
|
|
+ } catch (SaslException ignored) {
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public int readAndProcess() throws IOException, InterruptedException {
|
|
|
while (true) {
|
|
|
/* Read at most one RPC. If the header is not read completely yet
|
|
@@ -798,14 +909,33 @@ public abstract class Server {
|
|
|
return count;
|
|
|
}
|
|
|
|
|
|
- if (!versionRead) {
|
|
|
+ if (!rpcHeaderRead) {
|
|
|
//Every connection is expected to send the header.
|
|
|
- ByteBuffer versionBuffer = ByteBuffer.allocate(1);
|
|
|
- count = channelRead(channel, versionBuffer);
|
|
|
- if (count <= 0) {
|
|
|
+ if (rpcHeaderBuffer == null) {
|
|
|
+ rpcHeaderBuffer = ByteBuffer.allocate(2);
|
|
|
+ }
|
|
|
+ count = channelRead(channel, rpcHeaderBuffer);
|
|
|
+ if (count < 0 || rpcHeaderBuffer.remaining() > 0) {
|
|
|
return count;
|
|
|
}
|
|
|
- int version = versionBuffer.get(0);
|
|
|
+ int version = rpcHeaderBuffer.get(0);
|
|
|
+ byte[] method = new byte[] {rpcHeaderBuffer.get(1)};
|
|
|
+ authMethod = AuthMethod.read(new DataInputStream(
|
|
|
+ new ByteArrayInputStream(method)));
|
|
|
+ if (authMethod == null) {
|
|
|
+ throw new IOException("Unable to read authentication method");
|
|
|
+ }
|
|
|
+ if (UserGroupInformation.isSecurityEnabled()
|
|
|
+ && authMethod == AuthMethod.SIMPLE) {
|
|
|
+ throw new IOException("Authentication is required");
|
|
|
+ }
|
|
|
+ if (!UserGroupInformation.isSecurityEnabled()
|
|
|
+ && authMethod != AuthMethod.SIMPLE) {
|
|
|
+ throw new IOException("Authentication is not supported");
|
|
|
+ }
|
|
|
+ if (authMethod != AuthMethod.SIMPLE) {
|
|
|
+ useSasl = true;
|
|
|
+ }
|
|
|
|
|
|
dataLengthBuffer.flip();
|
|
|
if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) {
|
|
@@ -817,7 +947,8 @@ public abstract class Server {
|
|
|
return -1;
|
|
|
}
|
|
|
dataLengthBuffer.clear();
|
|
|
- versionRead = true;
|
|
|
+ rpcHeaderBuffer = null;
|
|
|
+ rpcHeaderRead = true;
|
|
|
continue;
|
|
|
}
|
|
|
|
|
@@ -825,12 +956,11 @@ public abstract class Server {
|
|
|
dataLengthBuffer.flip();
|
|
|
dataLength = dataLengthBuffer.getInt();
|
|
|
|
|
|
- if (dataLength == Client.PING_CALL_ID) {
|
|
|
+ if (!useSasl && dataLength == Client.PING_CALL_ID) {
|
|
|
dataLengthBuffer.clear();
|
|
|
return 0; //ping message
|
|
|
}
|
|
|
data = ByteBuffer.allocate(dataLength);
|
|
|
- incRpcCount(); // Increment the rpc count
|
|
|
}
|
|
|
|
|
|
count = channelRead(channel, data);
|
|
@@ -838,33 +968,14 @@ public abstract class Server {
|
|
|
if (data.remaining() == 0) {
|
|
|
dataLengthBuffer.clear();
|
|
|
data.flip();
|
|
|
- if (headerRead) {
|
|
|
- processData();
|
|
|
- data = null;
|
|
|
- return count;
|
|
|
+ boolean isHeaderRead = headerRead;
|
|
|
+ if (useSasl) {
|
|
|
+ saslReadAndProcess(data.array());
|
|
|
} else {
|
|
|
- processHeader();
|
|
|
- headerRead = true;
|
|
|
- data = null;
|
|
|
-
|
|
|
- // Authorize the connection
|
|
|
- try {
|
|
|
- authorize(user, header);
|
|
|
-
|
|
|
- if (LOG.isDebugEnabled()) {
|
|
|
- LOG.debug("Successfully authorized " + header);
|
|
|
- }
|
|
|
- } catch (AuthorizationException ae) {
|
|
|
- authFailedCall.connection = this;
|
|
|
- setupResponse(authFailedResponse, authFailedCall,
|
|
|
- Status.FATAL, null,
|
|
|
- ae.getClass().getName(), ae.getMessage());
|
|
|
- responder.doRespond(authFailedCall);
|
|
|
-
|
|
|
- // Close this connection
|
|
|
- return -1;
|
|
|
- }
|
|
|
-
|
|
|
+ processOneRpc(data.array());
|
|
|
+ }
|
|
|
+ data = null;
|
|
|
+ if (!isHeaderRead) {
|
|
|
continue;
|
|
|
}
|
|
|
}
|
|
@@ -873,9 +984,9 @@ public abstract class Server {
|
|
|
}
|
|
|
|
|
|
/// Reads the connection header following version
|
|
|
- private void processHeader() throws IOException {
|
|
|
+ private void processHeader(byte[] buf) throws IOException {
|
|
|
DataInputStream in =
|
|
|
- new DataInputStream(new ByteArrayInputStream(data.array()));
|
|
|
+ new DataInputStream(new ByteArrayInputStream(buf));
|
|
|
header.readFields(in);
|
|
|
try {
|
|
|
String protocolClassName = header.getProtocol();
|
|
@@ -886,12 +997,73 @@ public abstract class Server {
|
|
|
throw new IOException("Unknown protocol: " + header.getProtocol());
|
|
|
}
|
|
|
|
|
|
- user = header.getUgi();
|
|
|
+ UserGroupInformation protocolUser = header.getUgi();
|
|
|
+ if (!useSasl) {
|
|
|
+ user = protocolUser;
|
|
|
+ } else if (protocolUser != null && !protocolUser.equals(user)) {
|
|
|
+ throw new AccessControlException("Authenticated user (" + user
|
|
|
+ + ") doesn't match what the client claims to be (" + protocolUser
|
|
|
+ + ")");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void processUnwrappedData(byte[] inBuf) throws IOException,
|
|
|
+ InterruptedException {
|
|
|
+ ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
|
|
|
+ inBuf));
|
|
|
+ // Read all RPCs contained in the inBuf, even partial ones
|
|
|
+ while (true) {
|
|
|
+ int count = -1;
|
|
|
+ if (unwrappedDataLengthBuffer.remaining() > 0) {
|
|
|
+ count = channelRead(ch, unwrappedDataLengthBuffer);
|
|
|
+ if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0)
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (unwrappedData == null) {
|
|
|
+ unwrappedDataLengthBuffer.flip();
|
|
|
+ int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
|
|
|
+
|
|
|
+ if (unwrappedDataLength == Client.PING_CALL_ID) {
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Received ping message");
|
|
|
+ unwrappedDataLengthBuffer.clear();
|
|
|
+ continue; // ping message
|
|
|
+ }
|
|
|
+ unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
|
|
|
+ }
|
|
|
+
|
|
|
+ count = channelRead(ch, unwrappedData);
|
|
|
+ if (count <= 0 || unwrappedData.remaining() > 0)
|
|
|
+ return;
|
|
|
+
|
|
|
+ if (unwrappedData.remaining() == 0) {
|
|
|
+ unwrappedDataLengthBuffer.clear();
|
|
|
+ unwrappedData.flip();
|
|
|
+ processOneRpc(unwrappedData.array());
|
|
|
+ unwrappedData = null;
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- private void processData() throws IOException, InterruptedException {
|
|
|
+ private void processOneRpc(byte[] buf) throws IOException,
|
|
|
+ InterruptedException {
|
|
|
+ if (headerRead) {
|
|
|
+ processData(buf);
|
|
|
+ } else {
|
|
|
+ processHeader(buf);
|
|
|
+ headerRead = true;
|
|
|
+ if (!authorizeConnection()) {
|
|
|
+ throw new AccessControlException("Connection from " + this
|
|
|
+ + " for protocol " + header.getProtocol()
|
|
|
+ + " is unauthorized for user " + user);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void processData(byte[] buf) throws IOException, InterruptedException {
|
|
|
DataInputStream dis =
|
|
|
- new DataInputStream(new ByteArrayInputStream(data.array()));
|
|
|
+ new DataInputStream(new ByteArrayInputStream(buf));
|
|
|
int id = dis.readInt(); // try to read an id
|
|
|
|
|
|
if (LOG.isDebugEnabled())
|
|
@@ -902,9 +1074,27 @@ public abstract class Server {
|
|
|
|
|
|
Call call = new Call(id, param, this);
|
|
|
callQueue.put(call); // queue the call; maybe blocked here
|
|
|
+ incRpcCount(); // Increment the rpc count
|
|
|
}
|
|
|
|
|
|
+ private boolean authorizeConnection() throws IOException {
|
|
|
+ try {
|
|
|
+ authorize(user, header);
|
|
|
+ if (LOG.isDebugEnabled()) {
|
|
|
+ LOG.debug("Successfully authorized " + header);
|
|
|
+ }
|
|
|
+ } catch (AuthorizationException ae) {
|
|
|
+ authFailedCall.connection = this;
|
|
|
+ setupResponse(authFailedResponse, authFailedCall, Status.FATAL, null,
|
|
|
+ ae.getClass().getName(), ae.getMessage());
|
|
|
+ responder.doRespond(authFailedCall);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
private synchronized void close() throws IOException {
|
|
|
+ disposeSasl();
|
|
|
data = null;
|
|
|
dataLengthBuffer = null;
|
|
|
if (!channel.isOpen())
|
|
@@ -993,16 +1183,17 @@ public abstract class Server {
|
|
|
Configuration conf)
|
|
|
throws IOException
|
|
|
{
|
|
|
- this(bindAddress, port, paramClass, handlerCount, conf, Integer.toString(port));
|
|
|
+ this(bindAddress, port, paramClass, handlerCount, conf, Integer.toString(port), null);
|
|
|
}
|
|
|
/** Constructs a server listening on the named port and address. Parameters passed must
|
|
|
* be of the named class. The <code>handlerCount</handlerCount> determines
|
|
|
* the number of handler threads that will be used to process calls.
|
|
|
*
|
|
|
*/
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
protected Server(String bindAddress, int port,
|
|
|
Class<? extends Writable> paramClass, int handlerCount,
|
|
|
- Configuration conf, String serverName)
|
|
|
+ Configuration conf, String serverName, SecretManager<? extends TokenIdentifier> secretManager)
|
|
|
throws IOException {
|
|
|
this.bindAddress = bindAddress;
|
|
|
this.conf = conf;
|
|
@@ -1015,6 +1206,7 @@ public abstract class Server {
|
|
|
this.maxIdleTime = 2*conf.getInt("ipc.client.connection.maxidletime", 1000);
|
|
|
this.maxConnectionsToNuke = conf.getInt("ipc.client.kill.max", 10);
|
|
|
this.thresholdIdleConnections = conf.getInt("ipc.client.idlethreshold", 4000);
|
|
|
+ this.secretManager = (SecretManager<TokenIdentifier>) secretManager;
|
|
|
this.authorize =
|
|
|
conf.getBoolean(ServiceAuthorizationManager.SERVICE_AUTHORIZATION_CONFIG,
|
|
|
false);
|
|
@@ -1068,9 +1260,29 @@ public abstract class Server {
|
|
|
WritableUtils.writeString(out, errorClass);
|
|
|
WritableUtils.writeString(out, error);
|
|
|
}
|
|
|
+ wrapWithSasl(response, call);
|
|
|
call.setResponse(ByteBuffer.wrap(response.toByteArray()));
|
|
|
}
|
|
|
|
|
|
+ private void wrapWithSasl(ByteArrayOutputStream response, Call call)
|
|
|
+ throws IOException {
|
|
|
+ if (call.connection.useSasl) {
|
|
|
+ byte[] token = response.toByteArray();
|
|
|
+ // synchronization may be needed since there can be multiple Handler
|
|
|
+ // threads using saslServer to wrap responses.
|
|
|
+ synchronized (call.connection.saslServer) {
|
|
|
+ token = call.connection.saslServer.wrap(token, 0, token.length);
|
|
|
+ }
|
|
|
+ if (LOG.isDebugEnabled())
|
|
|
+ LOG.debug("Adding saslServer wrapped token of size " + token.length
|
|
|
+ + " as call response.");
|
|
|
+ response.reset();
|
|
|
+ DataOutputStream saslOut = new DataOutputStream(response);
|
|
|
+ saslOut.writeInt(token.length);
|
|
|
+ saslOut.write(token, 0, token.length);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
Configuration getConf() {
|
|
|
return conf;
|
|
|
}
|