|
@@ -37,6 +37,7 @@ import java.security.PrivilegedExceptionAction;
|
|
|
import java.util.Hashtable;
|
|
|
import java.util.Iterator;
|
|
|
import java.util.Random;
|
|
|
+import java.util.Set;
|
|
|
import java.util.Map.Entry;
|
|
|
import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
import java.util.concurrent.atomic.AtomicLong;
|
|
@@ -45,6 +46,8 @@ import javax.net.SocketFactory;
|
|
|
|
|
|
import org.apache.commons.logging.*;
|
|
|
|
|
|
+import org.apache.hadoop.classification.InterfaceAudience;
|
|
|
+import org.apache.hadoop.classification.InterfaceStability;
|
|
|
import org.apache.hadoop.conf.Configuration;
|
|
|
import org.apache.hadoop.io.IOUtils;
|
|
|
import org.apache.hadoop.io.Text;
|
|
@@ -80,12 +83,6 @@ public class Client {
|
|
|
private int counter; // counter for call ids
|
|
|
private AtomicBoolean running = new AtomicBoolean(true); // if client runs
|
|
|
final private Configuration conf;
|
|
|
- final private int maxIdleTime; //connections will be culled if it was idle for
|
|
|
- //maxIdleTime msecs
|
|
|
- final private int maxRetries; //the max. no. of retries for socket connections
|
|
|
- private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
|
|
|
- private int pingInterval; // how often sends ping to the server in msecs
|
|
|
- final private boolean doPing; //do we need to send ping message
|
|
|
|
|
|
private SocketFactory socketFactory; // how to create sockets
|
|
|
private int refCount = 1;
|
|
@@ -220,6 +217,12 @@ public class Client {
|
|
|
private DataInputStream in;
|
|
|
private DataOutputStream out;
|
|
|
private int rpcTimeout;
|
|
|
+ private int maxIdleTime; //connections will be culled if it was idle for
|
|
|
+ //maxIdleTime msecs
|
|
|
+ private int maxRetries; //the max. no. of retries for socket connections
|
|
|
+ private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
|
|
|
+ private boolean doPing; //do we need to send ping message
|
|
|
+ private int pingInterval; // how often sends ping to the server in msecs
|
|
|
|
|
|
// currently active calls
|
|
|
private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
|
|
@@ -235,6 +238,15 @@ public class Client {
|
|
|
remoteId.getAddress().getHostName());
|
|
|
}
|
|
|
this.rpcTimeout = remoteId.getRpcTimeout();
|
|
|
+ this.maxIdleTime = remoteId.getMaxIdleTime();
|
|
|
+ this.maxRetries = remoteId.getMaxRetries();
|
|
|
+ this.tcpNoDelay = remoteId.getTcpNoDelay();
|
|
|
+ this.doPing = remoteId.getDoPing();
|
|
|
+ this.pingInterval = remoteId.getPingInterval();
|
|
|
+ if (LOG.isDebugEnabled()) {
|
|
|
+ LOG.debug("The ping interval is" + this.pingInterval + "ms.");
|
|
|
+ }
|
|
|
+
|
|
|
UserGroupInformation ticket = remoteId.getTicket();
|
|
|
Class<?> protocol = remoteId.getProtocol();
|
|
|
this.useSasl = UserGroupInformation.isSecurityEnabled();
|
|
@@ -256,15 +268,9 @@ public class Client {
|
|
|
}
|
|
|
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
|
|
|
if (krbInfo != null) {
|
|
|
- String serverKey = krbInfo.serverPrincipal();
|
|
|
- if (serverKey == null) {
|
|
|
- throw new IOException(
|
|
|
- "Can't obtain server Kerberos config key from KerberosInfo");
|
|
|
- }
|
|
|
- serverPrincipal = SecurityUtil.getServerPrincipal(
|
|
|
- conf.get(serverKey), server.getAddress().getCanonicalHostName());
|
|
|
+ serverPrincipal = remoteId.getServerPrincipal();
|
|
|
if (LOG.isDebugEnabled()) {
|
|
|
- LOG.debug("RPC Server Kerberos principal name for protocol="
|
|
|
+ LOG.debug("RPC Server's Kerberos principal name for protocol="
|
|
|
+ protocol.getCanonicalName() + " is " + serverPrincipal);
|
|
|
}
|
|
|
}
|
|
@@ -882,15 +888,6 @@ public class Client {
|
|
|
public Client(Class<? extends Writable> valueClass, Configuration conf,
|
|
|
SocketFactory factory) {
|
|
|
this.valueClass = valueClass;
|
|
|
- this.maxIdleTime =
|
|
|
- conf.getInt("ipc.client.connection.maxidletime", 10000); //10s
|
|
|
- this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10);
|
|
|
- this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false);
|
|
|
- this.doPing = conf.getBoolean("ipc.client.ping", true);
|
|
|
- this.pingInterval = getPingInterval(conf);
|
|
|
- if (LOG.isDebugEnabled()) {
|
|
|
- LOG.debug("The ping interval is" + this.pingInterval + "ms.");
|
|
|
- }
|
|
|
this.conf = conf;
|
|
|
this.socketFactory = factory;
|
|
|
}
|
|
@@ -942,7 +939,7 @@ public class Client {
|
|
|
/** Make a call, passing <code>param</code>, to the IPC server running at
|
|
|
* <code>address</code>, returning the value. Throws exceptions if there are
|
|
|
* network problems or if the remote code threw an exception.
|
|
|
- * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
|
|
|
+ * @deprecated Use {@link #call(Writable, ConnectionId)} instead
|
|
|
*/
|
|
|
@Deprecated
|
|
|
public Writable call(Writable param, InetSocketAddress address)
|
|
@@ -955,27 +952,60 @@ public class Client {
|
|
|
* the value.
|
|
|
* Throws exceptions if there are network problems or if the remote code
|
|
|
* threw an exception.
|
|
|
- * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
|
|
|
+ * @deprecated Use {@link #call(Writable, ConnectionId)} instead
|
|
|
*/
|
|
|
@Deprecated
|
|
|
public Writable call(Writable param, InetSocketAddress addr,
|
|
|
UserGroupInformation ticket)
|
|
|
throws InterruptedException, IOException {
|
|
|
- return call(param, addr, null, ticket, 0);
|
|
|
+ ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0,
|
|
|
+ conf);
|
|
|
+ return call(param, remoteId);
|
|
|
}
|
|
|
|
|
|
/** Make a call, passing <code>param</code>, to the IPC server running at
|
|
|
* <code>address</code> which is servicing the <code>protocol</code> protocol,
|
|
|
- * with the <code>ticket</code> credentials, returning the value.
|
|
|
+ * with the <code>ticket</code> credentials and <code>rpcTimeout</code> as
|
|
|
+ * timeout, returning the value.
|
|
|
* Throws exceptions if there are network problems or if the remote code
|
|
|
- * threw an exception. */
|
|
|
+ * threw an exception.
|
|
|
+ * @deprecated Use {@link #call(Writable, ConnectionId)} instead
|
|
|
+ */
|
|
|
+ @Deprecated
|
|
|
public Writable call(Writable param, InetSocketAddress addr,
|
|
|
Class<?> protocol, UserGroupInformation ticket,
|
|
|
int rpcTimeout)
|
|
|
throws InterruptedException, IOException {
|
|
|
+ ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
|
|
|
+ ticket, rpcTimeout, conf);
|
|
|
+ return call(param, remoteId);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Make a call, passing <code>param</code>, to the IPC server running at
|
|
|
+ * <code>address</code> which is servicing the <code>protocol</code> protocol,
|
|
|
+ * with the <code>ticket</code> credentials, <code>rpcTimeout</code> as
|
|
|
+ * timeout and <code>conf</code> as conf for this connection, returning the
|
|
|
+ * value. Throws exceptions if there are network problems or if the remote
|
|
|
+ * code threw an exception.
|
|
|
+ */
|
|
|
+ public Writable call(Writable param, InetSocketAddress addr,
|
|
|
+ Class<?> protocol, UserGroupInformation ticket,
|
|
|
+ int rpcTimeout, Configuration conf)
|
|
|
+ throws InterruptedException, IOException {
|
|
|
+ ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
|
|
|
+ ticket, rpcTimeout, conf);
|
|
|
+ return call(param, remoteId);
|
|
|
+ }
|
|
|
+
|
|
|
+ /** Make a call, passing <code>param</code>, to the IPC server defined by
|
|
|
+ * <code>remoteId</code>, returning the value.
|
|
|
+ * Throws exceptions if there are network problems or if the remote code
|
|
|
+ * threw an exception. */
|
|
|
+ public Writable call(Writable param, ConnectionId remoteId)
|
|
|
+ throws InterruptedException, IOException {
|
|
|
Call call = new Call(param);
|
|
|
- Connection connection = getConnection(
|
|
|
- addr, protocol, ticket, rpcTimeout, call);
|
|
|
+ Connection connection = getConnection(remoteId, call);
|
|
|
connection.sendParam(call); // send the parameter
|
|
|
boolean interrupted = false;
|
|
|
synchronized (call) {
|
|
@@ -998,7 +1028,7 @@ public class Client {
|
|
|
call.error.fillInStackTrace();
|
|
|
throw call.error;
|
|
|
} else { // local exception
|
|
|
- throw wrapException(addr, call.error);
|
|
|
+ throw wrapException(remoteId.getAddress(), call.error);
|
|
|
}
|
|
|
} else {
|
|
|
return call.value;
|
|
@@ -1038,25 +1068,34 @@ public class Client {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * Makes a set of calls in parallel. Each parameter is sent to the
|
|
|
- * corresponding address. When all values are available, or have timed out
|
|
|
- * or errored, the collected results are returned in an array. The array
|
|
|
- * contains nulls for calls that timed out or errored.
|
|
|
- * @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead
|
|
|
+ * @deprecated Use {@link #call(Writable[], InetSocketAddress[],
|
|
|
+ * Class, UserGroupInformation, Configuration)} instead
|
|
|
*/
|
|
|
@Deprecated
|
|
|
public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
|
|
|
throws IOException, InterruptedException {
|
|
|
- return call(params, addresses, null, null);
|
|
|
+ return call(params, addresses, null, null, conf);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @deprecated Use {@link #call(Writable[], InetSocketAddress[],
|
|
|
+ * Class, UserGroupInformation, Configuration)} instead
|
|
|
+ */
|
|
|
+ @Deprecated
|
|
|
+ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
|
|
+ Class<?> protocol, UserGroupInformation ticket)
|
|
|
+ throws IOException, InterruptedException {
|
|
|
+ return call(params, addresses, protocol, ticket, conf);
|
|
|
}
|
|
|
|
|
|
+
|
|
|
/** Makes a set of calls in parallel. Each parameter is sent to the
|
|
|
* corresponding address. When all values are available, or have timed out
|
|
|
* or errored, the collected results are returned in an array. The array
|
|
|
* contains nulls for calls that timed out or errored. */
|
|
|
- public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
|
|
- Class<?> protocol, UserGroupInformation ticket)
|
|
|
- throws IOException, InterruptedException {
|
|
|
+ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
|
|
+ Class<?> protocol, UserGroupInformation ticket, Configuration conf)
|
|
|
+ throws IOException, InterruptedException {
|
|
|
if (addresses.length == 0) return new Writable[0];
|
|
|
|
|
|
ParallelResults results = new ParallelResults(params.length);
|
|
@@ -1064,8 +1103,9 @@ public class Client {
|
|
|
for (int i = 0; i < params.length; i++) {
|
|
|
ParallelCall call = new ParallelCall(params[i], results, i);
|
|
|
try {
|
|
|
- Connection connection =
|
|
|
- getConnection(addresses[i], protocol, ticket, 0, call);
|
|
|
+ ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
|
|
|
+ protocol, ticket, 0, conf);
|
|
|
+ Connection connection = getConnection(remoteId, call);
|
|
|
connection.sendParam(call); // send each parameter
|
|
|
} catch (IOException e) {
|
|
|
// log errors
|
|
@@ -1084,12 +1124,18 @@ public class Client {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // for unit testing only
|
|
|
+ @InterfaceAudience.Private
|
|
|
+ @InterfaceStability.Unstable
|
|
|
+ Set<ConnectionId> getConnectionIds() {
|
|
|
+ synchronized (connections) {
|
|
|
+ return connections.keySet();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/** Get a connection from the pool, or create a new one and add it to the
|
|
|
- * pool. Connections to a given host/port are reused. */
|
|
|
- private Connection getConnection(InetSocketAddress addr,
|
|
|
- Class<?> protocol,
|
|
|
- UserGroupInformation ticket,
|
|
|
- int rpcTimeout,
|
|
|
+ * pool. Connections to a given ConnectionId are reused. */
|
|
|
+ private Connection getConnection(ConnectionId remoteId,
|
|
|
Call call)
|
|
|
throws IOException, InterruptedException {
|
|
|
if (!running.get()) {
|
|
@@ -1101,8 +1147,6 @@ public class Client {
|
|
|
* connectionsId object and with set() method. We need to manage the
|
|
|
* refs for keys in HashMap properly. For now its ok.
|
|
|
*/
|
|
|
- ConnectionId remoteId = new ConnectionId(
|
|
|
- addr, protocol, ticket, rpcTimeout);
|
|
|
do {
|
|
|
synchronized (connections) {
|
|
|
connection = connections.get(remoteId);
|
|
@@ -1120,24 +1164,40 @@ public class Client {
|
|
|
connection.setupIOstreams();
|
|
|
return connection;
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
/**
|
|
|
* This class holds the address and the user ticket. The client connections
|
|
|
* to servers are uniquely identified by <remoteAddress, protocol, ticket>
|
|
|
*/
|
|
|
- private static class ConnectionId {
|
|
|
+ static class ConnectionId {
|
|
|
InetSocketAddress address;
|
|
|
UserGroupInformation ticket;
|
|
|
Class<?> protocol;
|
|
|
private static final int PRIME = 16777619;
|
|
|
private int rpcTimeout;
|
|
|
+ private String serverPrincipal;
|
|
|
+ private int maxIdleTime; //connections will be culled if it was idle for
|
|
|
+ //maxIdleTime msecs
|
|
|
+ private int maxRetries; //the max. no. of retries for socket connections
|
|
|
+ private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
|
|
|
+ private boolean doPing; //do we need to send ping message
|
|
|
+ private int pingInterval; // how often sends ping to the server in msecs
|
|
|
|
|
|
ConnectionId(InetSocketAddress address, Class<?> protocol,
|
|
|
- UserGroupInformation ticket, int rpcTimeout) {
|
|
|
+ UserGroupInformation ticket, int rpcTimeout,
|
|
|
+ String serverPrincipal, int maxIdleTime,
|
|
|
+ int maxRetries, boolean tcpNoDelay,
|
|
|
+ boolean doPing, int pingInterval) {
|
|
|
this.protocol = protocol;
|
|
|
this.address = address;
|
|
|
this.ticket = ticket;
|
|
|
this.rpcTimeout = rpcTimeout;
|
|
|
+ this.serverPrincipal = serverPrincipal;
|
|
|
+ this.maxIdleTime = maxIdleTime;
|
|
|
+ this.maxRetries = maxRetries;
|
|
|
+ this.tcpNoDelay = tcpNoDelay;
|
|
|
+ this.doPing = doPing;
|
|
|
+ this.pingInterval = pingInterval;
|
|
|
}
|
|
|
|
|
|
InetSocketAddress getAddress() {
|
|
@@ -1156,25 +1216,102 @@ public class Client {
|
|
|
return rpcTimeout;
|
|
|
}
|
|
|
|
|
|
+ String getServerPrincipal() {
|
|
|
+ return serverPrincipal;
|
|
|
+ }
|
|
|
+
|
|
|
+ int getMaxIdleTime() {
|
|
|
+ return maxIdleTime;
|
|
|
+ }
|
|
|
+
|
|
|
+ int getMaxRetries() {
|
|
|
+ return maxRetries;
|
|
|
+ }
|
|
|
+
|
|
|
+ boolean getTcpNoDelay() {
|
|
|
+ return tcpNoDelay;
|
|
|
+ }
|
|
|
+
|
|
|
+ boolean getDoPing() {
|
|
|
+ return doPing;
|
|
|
+ }
|
|
|
+
|
|
|
+ int getPingInterval() {
|
|
|
+ return pingInterval;
|
|
|
+ }
|
|
|
+
|
|
|
+ static ConnectionId getConnectionId(InetSocketAddress addr,
|
|
|
+ Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
|
|
|
+ Configuration conf) throws IOException {
|
|
|
+ String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
|
|
|
+ return new ConnectionId(addr, protocol, ticket,
|
|
|
+ rpcTimeout, remotePrincipal,
|
|
|
+ conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
|
|
|
+ conf.getInt("ipc.client.connect.max.retries", 10),
|
|
|
+ conf.getBoolean("ipc.client.tcpnodelay", false),
|
|
|
+ conf.getBoolean("ipc.client.ping", true),
|
|
|
+ Client.getPingInterval(conf));
|
|
|
+ }
|
|
|
+
|
|
|
+ private static String getRemotePrincipal(Configuration conf,
|
|
|
+ InetSocketAddress address, Class<?> protocol) throws IOException {
|
|
|
+ if (protocol == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
|
|
|
+ if (krbInfo != null) {
|
|
|
+ String serverKey = krbInfo.serverPrincipal();
|
|
|
+ if (serverKey == null) {
|
|
|
+ throw new IOException(
|
|
|
+ "Can't obtain server Kerberos config key from protocol="
|
|
|
+ + protocol.getCanonicalName());
|
|
|
+ }
|
|
|
+ return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
|
|
|
+ .getAddress().getCanonicalHostName());
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ static boolean isEqual(Object a, Object b) {
|
|
|
+ return a == null ? b == null : a.equals(b);
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public boolean equals(Object obj) {
|
|
|
- if (obj instanceof ConnectionId) {
|
|
|
- ConnectionId id = (ConnectionId) obj;
|
|
|
- return address.equals(id.address) && protocol == id.protocol &&
|
|
|
- ((ticket != null && ticket.equals(id.ticket)) ||
|
|
|
- (ticket == id.ticket)) && rpcTimeout == id.rpcTimeout;
|
|
|
- }
|
|
|
- return false;
|
|
|
+ if (obj == this) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (obj instanceof ConnectionId) {
|
|
|
+ ConnectionId that = (ConnectionId) obj;
|
|
|
+ return isEqual(this.address, that.address)
|
|
|
+ && this.doPing == that.doPing
|
|
|
+ && this.maxIdleTime == that.maxIdleTime
|
|
|
+ && this.maxRetries == that.maxRetries
|
|
|
+ && this.pingInterval == that.pingInterval
|
|
|
+ && isEqual(this.protocol, that.protocol)
|
|
|
+ && this.rpcTimeout == that.rpcTimeout
|
|
|
+ && isEqual(this.serverPrincipal, that.serverPrincipal)
|
|
|
+ && this.tcpNoDelay == that.tcpNoDelay
|
|
|
+ && isEqual(this.ticket, that.ticket);
|
|
|
+ }
|
|
|
+ return false;
|
|
|
}
|
|
|
|
|
|
- @Override // simply use the default Object#hashcode() ?
|
|
|
+ @Override
|
|
|
public int hashCode() {
|
|
|
- return (address.hashCode() + PRIME * (
|
|
|
- PRIME * (
|
|
|
- PRIME * System.identityHashCode(protocol) ^
|
|
|
- System.identityHashCode(ticket)
|
|
|
- ) ^ System.identityHashCode(rpcTimeout)
|
|
|
- ));
|
|
|
+ int result = 1;
|
|
|
+ result = PRIME * result + ((address == null) ? 0 : address.hashCode());
|
|
|
+ result = PRIME * result + (doPing ? 1231 : 1237);
|
|
|
+ result = PRIME * result + maxIdleTime;
|
|
|
+ result = PRIME * result + maxRetries;
|
|
|
+ result = PRIME * result + pingInterval;
|
|
|
+ result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
|
|
|
+ result = PRIME * result + rpcTimeout;
|
|
|
+ result = PRIME * result
|
|
|
+ + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
|
|
|
+ result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
|
|
|
+ result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
|
|
|
+ return result;
|
|
|
}
|
|
|
}
|
|
|
}
|