소스 검색

HADOOP-14034. Allow ipc layer exceptions to selectively close connections. Contributed by Daryn Sharp.

Kihwal Lee 8 년 전
부모
커밋
d008b55153

+ 106 - 103
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java

@@ -1240,20 +1240,16 @@ public abstract class Server {
         LOG.info(Thread.currentThread().getName() + ": readAndProcess caught InterruptedException", ieo);
         throw ieo;
       } catch (Exception e) {
-        // Do not log WrappedRpcServerExceptionSuppressed.
-        if (!(e instanceof WrappedRpcServerExceptionSuppressed)) {
-          // A WrappedRpcServerException is an exception that has been sent
-          // to the client, so the stacktrace is unnecessary; any other
-          // exceptions are unexpected internal server errors and thus the
-          // stacktrace should be logged.
-          LOG.info(Thread.currentThread().getName() +
-              ": readAndProcess from client " + c.getHostAddress() +
-              " threw exception [" + e + "]",
-              (e instanceof WrappedRpcServerException) ? null : e);
-        }
+        // Any exceptions that reach here are fatal unexpected internal errors
+        // that could not be sent to the client.
+        LOG.info(Thread.currentThread().getName() +
+            ": readAndProcess from client " + c +
+            " threw exception [" + e + "]", e);
         count = -1; //so that the (count < 0) block is executed
       }
-      if (count < 0) {
+      // setupResponse will signal the connection should be closed when a
+      // fatal response is sent.
+      if (count < 0 || c.shouldClose()) {
         closeConnection(c);
         c = null;
       }
@@ -1581,16 +1577,20 @@ public abstract class Server {
    * unnecessary stack trace logging if it's not an internal server error. 
    */
   @SuppressWarnings("serial")
-  private static class WrappedRpcServerException extends RpcServerException {
+  private static class FatalRpcServerException extends RpcServerException {
     private final RpcErrorCodeProto errCode;
-    public WrappedRpcServerException(RpcErrorCodeProto errCode, IOException ioe) {
+    public FatalRpcServerException(RpcErrorCodeProto errCode, IOException ioe) {
       super(ioe.toString(), ioe);
       this.errCode = errCode;
     }
-    public WrappedRpcServerException(RpcErrorCodeProto errCode, String message) {
+    public FatalRpcServerException(RpcErrorCodeProto errCode, String message) {
       this(errCode, new RpcServerException(message));
     }
     @Override
+    public RpcStatusProto getRpcStatusProto() {
+      return RpcStatusProto.FATAL;
+    }
+    @Override
     public RpcErrorCodeProto getRpcErrorCodeProto() {
       return errCode;
     }
@@ -1600,19 +1600,6 @@ public abstract class Server {
     }
   }
 
-  /**
-   * A WrappedRpcServerException that is suppressed altogether
-   * for the purposes of logging.
-   */
-  @SuppressWarnings("serial")
-  private static class WrappedRpcServerExceptionSuppressed
-      extends WrappedRpcServerException {
-    public WrappedRpcServerExceptionSuppressed(
-        RpcErrorCodeProto errCode, IOException ioe) {
-      super(errCode, ioe);
-    }
-  }
-
   /** Reads calls from a connection and queues them for handling. */
   public class Connection {
     private boolean connectionHeaderRead = false; // connection  header is read?
@@ -1644,7 +1631,8 @@ public abstract class Server {
     private ByteBuffer unwrappedData;
     private ByteBuffer unwrappedDataLengthBuffer;
     private int serviceClass;
-    
+    private boolean shouldClose = false;
+
     UserGroupInformation user = null;
     public UserGroupInformation attemptingUser = null; // user name before auth
 
@@ -1685,7 +1673,15 @@ public abstract class Server {
     public String toString() {
       return getHostAddress() + ":" + remotePort; 
     }
-    
+
+    boolean setShouldClose() {
+      return shouldClose = true;
+    }
+
+    boolean shouldClose() {
+      return shouldClose;
+    }
+
     public String getHostAddress() {
       return hostAddress;
     }
@@ -1739,13 +1735,13 @@ public abstract class Server {
     }
 
     private void saslReadAndProcess(RpcWritable.Buffer buffer) throws
-    WrappedRpcServerException, IOException, InterruptedException {
+        RpcServerException, IOException, InterruptedException {
       final RpcSaslProto saslMessage =
           getMessage(RpcSaslProto.getDefaultInstance(), buffer);
       switch (saslMessage.getState()) {
         case WRAP: {
           if (!saslContextEstablished || !useWrap) {
-            throw new WrappedRpcServerException(
+            throw new FatalRpcServerException(
                 RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
                 new SaslException("Server is not wrapping data"));
           }
@@ -1780,11 +1776,11 @@ public abstract class Server {
       }
       return e;
     }
-    
+
     private void saslProcess(RpcSaslProto saslMessage)
-        throws WrappedRpcServerException, IOException, InterruptedException {
+        throws RpcServerException, IOException, InterruptedException {
       if (saslContextEstablished) {
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
             new SaslException("Negotiation is already complete"));
       }
@@ -1818,10 +1814,10 @@ public abstract class Server {
           AUDITLOG.info(AUTH_SUCCESSFUL_FOR + user);
           saslContextEstablished = true;
         }
-      } catch (WrappedRpcServerException wrse) { // don't re-wrap
-        throw wrse;
+      } catch (RpcServerException rse) { // don't re-wrap
+        throw rse;
       } catch (IOException ioe) {
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_UNAUTHORIZED, ioe);
       }
       // send back response if any, may throw IOException
@@ -1939,14 +1935,14 @@ public abstract class Server {
       setupResponse(saslCall,
           RpcStatusProto.SUCCESS, null,
           RpcWritable.wrap(message), null, null);
-      saslCall.sendResponse();
+      sendResponse(saslCall);
     }
 
     private void doSaslReply(Exception ioe) throws IOException {
       setupResponse(authFailedCall,
           RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED,
           null, ioe.getClass().getName(), ioe.getLocalizedMessage());
-      authFailedCall.sendResponse();
+      sendResponse(authFailedCall);
     }
 
     private void disposeSasl() {
@@ -1975,12 +1971,8 @@ public abstract class Server {
       }
     }
 
-    public int readAndProcess()
-        throws WrappedRpcServerException, IOException, InterruptedException {
-      while (true) {
-        /* Read at most one RPC. If the header is not read completely yet
-         * then iterate until we read first RPC or until there is no data left.
-         */    
+    public int readAndProcess() throws IOException, InterruptedException {
+      while (!shouldClose()) { // stop if a fatal response has been sent.
         int count = -1;
         if (dataLengthBuffer.remaining() > 0) {
           count = channelRead(channel, dataLengthBuffer);       
@@ -2042,15 +2034,17 @@ public abstract class Server {
         if (data.remaining() == 0) {
           dataLengthBuffer.clear();
           data.flip();
+          ByteBuffer requestData = data;
+          data = null; // null out in case processOneRpc throws.
           boolean isHeaderRead = connectionContextRead;
-          processOneRpc(data);
-          data = null;
+          processOneRpc(requestData);
           if (!isHeaderRead) {
             continue;
           }
         } 
         return count;
       }
+      return -1;
     }
 
     private AuthProtocol initializeAuthContext(int authType)
@@ -2125,14 +2119,14 @@ public abstract class Server {
         setupResponse(fakeCall,
             RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH,
             null, VersionMismatch.class.getName(), errMsg);
-        fakeCall.sendResponse();
+        sendResponse(fakeCall);
       } else if (clientVersion >= 3) {
         RpcCall fakeCall = new RpcCall(this, -1);
         // Versions 3 to 8 use older response
         setupResponseOldVersionFatal(buffer, fakeCall,
             null, VersionMismatch.class.getName(), errMsg);
 
-        fakeCall.sendResponse();
+        sendResponse(fakeCall);
       } else if (clientVersion == 2) { // Hadoop 0.18.3
         RpcCall fakeCall = new RpcCall(this, 0);
         DataOutputStream out = new DataOutputStream(buffer);
@@ -2141,7 +2135,7 @@ public abstract class Server {
         WritableUtils.writeString(out, VersionMismatch.class.getName());
         WritableUtils.writeString(out, errMsg);
         fakeCall.setResponse(ByteBuffer.wrap(buffer.toByteArray()));
-        fakeCall.sendResponse();
+        sendResponse(fakeCall);
       }
     }
     
@@ -2149,19 +2143,19 @@ public abstract class Server {
       RpcCall fakeCall = new RpcCall(this, 0);
       fakeCall.setResponse(ByteBuffer.wrap(
           RECEIVED_HTTP_REQ_RESPONSE.getBytes(StandardCharsets.UTF_8)));
-      fakeCall.sendResponse();
+      sendResponse(fakeCall);
     }
 
     /** Reads the connection context following the connection header
      * @param buffer - DataInputStream from which to read the header
-     * @throws WrappedRpcServerException - if the header cannot be
+     * @throws RpcServerException - if the header cannot be
      *         deserialized, or the user is not authorized
      */ 
     private void processConnectionContext(RpcWritable.Buffer buffer)
-        throws WrappedRpcServerException {
+        throws RpcServerException {
       // allow only one connection context during a session
       if (connectionContextRead) {
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
             "Connection context already processed");
       }
@@ -2182,7 +2176,7 @@ public abstract class Server {
             && (!protocolUser.getUserName().equals(user.getUserName()))) {
           if (authMethod == AuthMethod.TOKEN) {
             // Not allowed to doAs if token authentication is used
-            throw new WrappedRpcServerException(
+            throw new FatalRpcServerException(
                 RpcErrorCodeProto.FATAL_UNAUTHORIZED,
                 new AccessControlException("Authenticated user (" + user
                     + ") doesn't match what the client claims to be ("
@@ -2213,7 +2207,7 @@ public abstract class Server {
      * @throws InterruptedException
      */    
     private void unwrapPacketAndProcessRpcs(byte[] inBuf)
-        throws WrappedRpcServerException, IOException, InterruptedException {
+        throws IOException, InterruptedException {
       if (LOG.isDebugEnabled()) {
         LOG.debug("Have read input token of size " + inBuf.length
             + " for processing by saslServer.unwrap()");
@@ -2222,7 +2216,7 @@ public abstract class Server {
       ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
           inBuf));
       // Read all RPCs contained in the inBuf, even partial ones
-      while (true) {
+      while (!shouldClose()) { // stop if a fatal response has been sent.
         int count = -1;
         if (unwrappedDataLengthBuffer.remaining() > 0) {
           count = channelRead(ch, unwrappedDataLengthBuffer);
@@ -2243,17 +2237,14 @@ public abstract class Server {
         if (unwrappedData.remaining() == 0) {
           unwrappedDataLengthBuffer.clear();
           unwrappedData.flip();
-          processOneRpc(unwrappedData);
-          unwrappedData = null;
+          ByteBuffer requestData = unwrappedData;
+          unwrappedData = null; // null out in case processOneRpc throws.
+          processOneRpc(requestData);
         }
       }
     }
     
     /**
-<<<<<<< HEAD
-     * Process an RPC Request - handle connection setup and decoding of
-     * request into a Call
-=======
      * Process one RPC Request from buffer read from socket stream 
      *  - decode rpc in a rpc-Call
      *  - handle out-of-band RPC requests such as the initial connectionContext
@@ -2264,17 +2255,16 @@ public abstract class Server {
      * if SASL then SASL has been established and the buf we are passed
      * has been unwrapped from SASL.
      * 
->>>>>>> 3d94da1... HADOOP-11552. Allow handoff on the server side for RPC requests. Contributed by Siddharth Seth
      * @param bb - contains the RPC request header and the rpc request
      * @throws IOException - internal error that should not be returned to
      *         client, typically failure to respond to client
-     * @throws WrappedRpcServerException - an exception to be sent back to
-     *         the client that does not require verbose logging by the
-     *         Listener thread
      * @throws InterruptedException
      */
     private void processOneRpc(ByteBuffer bb)
-        throws IOException, WrappedRpcServerException, InterruptedException {
+        throws IOException, InterruptedException {
+      // exceptions that escape this method are fatal to the connection.
+      // setupResponse will use the rpc status to determine if the connection
+      // should be closed.
       int callId = -1;
       int retry = RpcConstants.INVALID_RETRY_COUNT;
       try {
@@ -2291,40 +2281,47 @@ public abstract class Server {
         if (callId < 0) { // callIds typically used during connection setup
           processRpcOutOfBandRequest(header, buffer);
         } else if (!connectionContextRead) {
-          throw new WrappedRpcServerException(
+          throw new FatalRpcServerException(
               RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
               "Connection context not established");
         } else {
           processRpcRequest(header, buffer);
         }
-      } catch (WrappedRpcServerException wrse) { // inform client of error
-        Throwable ioe = wrse.getCause();
+      } catch (RpcServerException rse) {
+        // inform client of error, but do not rethrow else non-fatal
+        // exceptions will close connection!
+        if (LOG.isDebugEnabled()) {
+          LOG.debug(Thread.currentThread().getName() +
+              ": processOneRpc from client " + this +
+              " threw exception [" + rse + "]");
+        }
+        // use the wrapped exception if there is one.
+        Throwable t = (rse.getCause() != null) ? rse.getCause() : rse;
         final RpcCall call = new RpcCall(this, callId, retry);
         setupResponse(call,
-            RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null,
-            ioe.getClass().getName(), ioe.getMessage());
-        call.sendResponse();
-        throw wrse;
+            rse.getRpcStatusProto(), rse.getRpcErrorCodeProto(), null,
+            t.getClass().getName(), t.getMessage());
+        sendResponse(call);
       }
     }
 
     /**
      * Verify RPC header is valid
      * @param header - RPC request header
-     * @throws WrappedRpcServerException - header contains invalid values 
+     * @throws RpcServerException - header contains invalid values
      */
     private void checkRpcHeaders(RpcRequestHeaderProto header)
-        throws WrappedRpcServerException {
+        throws RpcServerException {
       if (!header.hasRpcOp()) {
         String err = " IPC Server: No rpc op in rpcRequestHeader";
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
       }
       if (header.getRpcOp() != 
           RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET) {
         String err = "IPC Server does not implement rpc header operation" + 
                 header.getRpcOp();
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
       }
       // If we know the rpc kind, get its class so that we can deserialize
@@ -2332,7 +2329,7 @@ public abstract class Server {
       // we continue with this original design.
       if (!header.hasRpcKind()) {
         String err = " IPC Server: No rpc kind in rpcRequestHeader";
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
       }
     }
@@ -2342,13 +2339,13 @@ public abstract class Server {
      * have been already read
      * @param header - RPC request header
      * @param buffer - stream to request payload
-     * @throws WrappedRpcServerException - due to fatal rpc layer issues such
-     *   as invalid header or deserialization error. In this case a RPC fatal
-     *   status response will later be sent back to client.
+     * @throws RpcServerException - generally due to fatal rpc layer issues
+     *   such as invalid header or deserialization error.  The call queue
+     *   may also throw a fatal or non-fatal exception on overflow.
      * @throws InterruptedException
      */
     private void processRpcRequest(RpcRequestHeaderProto header,
-        RpcWritable.Buffer buffer) throws WrappedRpcServerException,
+        RpcWritable.Buffer buffer) throws RpcServerException,
         InterruptedException {
       Class<? extends Writable> rpcRequestClass = 
           getRpcRequestWrapper(header.getRpcKind());
@@ -2357,18 +2354,20 @@ public abstract class Server {
             " from client " + getHostAddress());
         final String err = "Unknown rpc kind in rpc header"  + 
             header.getRpcKind();
-        throw new WrappedRpcServerException(
-            RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);   
+        throw new FatalRpcServerException(
+            RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
       }
       Writable rpcRequest;
       try { //Read the rpc request
         rpcRequest = buffer.newInstance(rpcRequestClass, conf);
+      } catch (RpcServerException rse) { // lets tests inject failures.
+        throw rse;
       } catch (Throwable t) { // includes runtime exception from newInstance
         LOG.warn("Unable to read call parameters for client " +
                  getHostAddress() + "on connection protocol " +
             this.protocolName + " for rpcKind " + header.getRpcKind(),  t);
         String err = "IPC server unable to read call parameters: "+ t.getMessage();
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST, err);
       }
         
@@ -2407,7 +2406,7 @@ public abstract class Server {
       try {
         queueCall(call);
       } catch (IOException ioe) {
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.ERROR_RPC_SERVER, ioe);
       }
       incRpcCount();  // Increment the rpc count
@@ -2418,20 +2417,20 @@ public abstract class Server {
      * reading and authorizing the connection header
      * @param header - RPC header
      * @param buffer - stream to request payload
-     * @throws WrappedRpcServerException - setup failed due to SASL
+     * @throws RpcServerException - setup failed due to SASL
      *         negotiation failure, premature or invalid connection context,
      *         or other state errors 
      * @throws IOException - failed to send a response back to the client
      * @throws InterruptedException
      */
     private void processRpcOutOfBandRequest(RpcRequestHeaderProto header,
-        RpcWritable.Buffer buffer) throws WrappedRpcServerException,
+        RpcWritable.Buffer buffer) throws RpcServerException,
             IOException, InterruptedException {
       final int callId = header.getCallId();
       if (callId == CONNECTION_CONTEXT_CALL_ID) {
         // SASL must be established prior to connection context
         if (authProtocol == AuthProtocol.SASL && !saslContextEstablished) {
-          throw new WrappedRpcServerException(
+          throw new FatalRpcServerException(
               RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
               "Connection header sent during SASL negotiation");
         }
@@ -2440,7 +2439,7 @@ public abstract class Server {
       } else if (callId == AuthProtocol.SASL.callId) {
         // if client was switched to simple, ignore first SASL message
         if (authProtocol != AuthProtocol.SASL) {
-          throw new WrappedRpcServerException(
+          throw new FatalRpcServerException(
               RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
               "SASL protocol not requested by client");
         }
@@ -2448,7 +2447,7 @@ public abstract class Server {
       } else if (callId == PING_CALL_ID) {
         LOG.debug("Received ping message");
       } else {
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
             "Unknown out of band call #" + callId);
       }
@@ -2456,9 +2455,9 @@ public abstract class Server {
 
     /**
      * Authorize proxy users to access this server
-     * @throws WrappedRpcServerException - user is not allowed to proxy
+     * @throws RpcServerException - user is not allowed to proxy
      */
-    private void authorizeConnection() throws WrappedRpcServerException {
+    private void authorizeConnection() throws RpcServerException {
       try {
         // If auth method is TOKEN, the token was obtained by the
         // real user for the effective user, therefore not required to
@@ -2478,34 +2477,34 @@ public abstract class Server {
             + " for protocol " + connectionContext.getProtocol()
             + " is unauthorized for user " + user);
         rpcMetrics.incrAuthorizationFailures();
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_UNAUTHORIZED, ae);
       }
     }
     
     /**
      * Decode the a protobuf from the given input stream 
-<<<<<<< HEAD
      * @param message - Representation of the type of message
      * @param buffer - a buffer to read the protobuf
-=======
->>>>>>> 3d94da1... HADOOP-11552. Allow handoff on the server side for RPC requests. Contributed by Siddharth Seth
      * @return Message - decoded protobuf
-     * @throws WrappedRpcServerException - deserialization failed
+     * @throws RpcServerException - deserialization failed
      */
     @SuppressWarnings("unchecked")
     <T extends Message> T getMessage(Message message,
-        RpcWritable.Buffer buffer) throws WrappedRpcServerException {
+        RpcWritable.Buffer buffer) throws RpcServerException {
       try {
         return (T)buffer.getValue(message);
       } catch (Exception ioe) {
         Class<?> protoClass = message.getClass();
-        throw new WrappedRpcServerException(
+        throw new FatalRpcServerException(
             RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST,
             "Error decoding " + protoClass.getSimpleName() + ": "+ ioe);
       }
     }
 
+    // ipc reader threads should invoke this directly, whereas handlers
+    // must invoke call.sendResponse to allow lifecycle management of
+    // external, postponed, deferred calls, etc.
     private void sendResponse(RpcCall call) throws IOException {
       responder.doRespond(call);
     }
@@ -2808,6 +2807,10 @@ public abstract class Server {
       RpcCall call, RpcStatusProto status, RpcErrorCodeProto erCode,
       Writable rv, String errorClass, String error)
           throws IOException {
+    // fatal responses will cause the reader to close the connection.
+    if (status == RpcStatusProto.FATAL) {
+      call.connection.setShouldClose();
+    }
     RpcResponseHeaderProto.Builder headerBuilder =
         RpcResponseHeaderProto.newBuilder();
     headerBuilder.setClientId(ByteString.copyFrom(call.clientId));

+ 119 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java

@@ -31,9 +31,12 @@ import org.apache.hadoop.io.retry.RetryPolicy;
 import org.apache.hadoop.io.retry.RetryProxy;
 import org.apache.hadoop.ipc.Client.ConnectionId;
 import org.apache.hadoop.ipc.Server.Call;
+import org.apache.hadoop.ipc.Server.Connection;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto;
+import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto;
 import org.apache.hadoop.ipc.protobuf.TestProtos;
 import org.apache.hadoop.metrics2.MetricsRecordBuilder;
+import org.apache.hadoop.metrics2.lib.MutableCounterLong;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.AccessControlException;
 import org.apache.hadoop.security.SecurityUtil;
@@ -64,6 +67,7 @@ import java.net.ConnectException;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
 import java.security.PrivilegedAction;
 import java.security.PrivilegedExceptionAction;
 import java.util.ArrayList;
@@ -77,6 +81,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -85,6 +90,10 @@ import static org.apache.hadoop.test.MetricsAsserts.assertCounterGt;
 import static org.apache.hadoop.test.MetricsAsserts.getLongCounter;
 import static org.apache.hadoop.test.MetricsAsserts.getMetrics;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.spy;
@@ -1365,6 +1374,116 @@ public class TestRPC extends TestRpcBase {
     }
   }
 
+  public static class FakeRequestClass extends RpcWritable {
+    static volatile IOException exception;
+    @Override
+    void writeTo(ResponseBuffer out) throws IOException {
+      throw new UnsupportedOperationException();
+    }
+    @Override
+    <T> T readFrom(ByteBuffer bb) throws IOException {
+      throw exception;
+    }
+  }
+
+  @SuppressWarnings("serial")
+  public static class TestReaderException extends IOException {
+    public TestReaderException(String msg) {
+      super(msg);
+    }
+    @Override
+    public boolean equals(Object t) {
+      return (t.getClass() == TestReaderException.class) &&
+             getMessage().equals(((TestReaderException)t).getMessage());
+    }
+  }
+
+  @Test (timeout=30000)
+  public void testReaderExceptions() throws Exception {
+    Server server = null;
+    TestRpcService proxy = null;
+
+    // will attempt to return this exception from a reader with and w/o
+    // the connection closing.
+    IOException expectedIOE = new TestReaderException("testing123");
+
+    @SuppressWarnings("serial")
+    IOException rseError = new RpcServerException("keepalive", expectedIOE){
+      @Override
+      public RpcStatusProto getRpcStatusProto() {
+        return RpcStatusProto.ERROR;
+      }
+    };
+    @SuppressWarnings("serial")
+    IOException rseFatal = new RpcServerException("disconnect", expectedIOE) {
+      @Override
+      public RpcStatusProto getRpcStatusProto() {
+        return RpcStatusProto.FATAL;
+      }
+    };
+
+    try {
+      RPC.Builder builder = newServerBuilder(conf)
+          .setQueueSizePerHandler(1).setNumHandlers(1).setVerbose(true);
+      server = setupTestServer(builder);
+      Whitebox.setInternalState(
+          server, "rpcRequestClass", FakeRequestClass.class);
+      MutableCounterLong authMetric =
+          (MutableCounterLong)Whitebox.getInternalState(
+              server.getRpcMetrics(), "rpcAuthorizationSuccesses");
+
+      proxy = getClient(addr, conf);
+      boolean isDisconnected = true;
+      Connection lastConn = null;
+      long expectedAuths = 0;
+
+      // fuzz the client.
+      for (int i=0; i < 128; i++) {
+        String reqName = "request[" + i + "]";
+        int r = ThreadLocalRandom.current().nextInt();
+        final boolean doDisconnect = r % 4 == 0;
+        LOG.info("TestDisconnect request[" + i + "] " +
+                 " shouldConnect=" + isDisconnected +
+                 " willDisconnect=" + doDisconnect);
+        if (isDisconnected) {
+          expectedAuths++;
+        }
+        try {
+          FakeRequestClass.exception = doDisconnect ? rseFatal : rseError;
+          proxy.ping(null, newEmptyRequest());
+          fail(reqName + " didn't fail");
+        } catch (ServiceException e) {
+          RemoteException re = (RemoteException)e.getCause();
+          assertEquals(reqName, expectedIOE, re.unwrapRemoteException());
+        }
+        // check authorizations to ensure new connection when expected,
+        // then conclusively determine if connections are disconnected
+        // correctly.
+        assertEquals(reqName, expectedAuths, authMetric.value());
+        if (!doDisconnect) {
+          // if it wasn't fatal, verify there's only one open connection.
+          Connection[] conns = server.getConnections();
+          assertEquals(reqName, 1, conns.length);
+          // verify whether the connection should have been reused.
+          if (isDisconnected) {
+            assertNotSame(reqName, lastConn, conns[0]);
+          } else {
+            assertSame(reqName, lastConn, conns[0]);
+          }
+          lastConn = conns[0];
+        } else if (lastConn != null) {
+          // avoid race condition in server where connection may not be
+          // fully removed yet.  just make sure it's marked for being closed.
+          // the open connection checks above ensure correct behavior.
+          assertTrue(reqName, lastConn.shouldClose());
+        }
+        isDisconnected = doDisconnect;
+      }
+    } finally {
+      stop(server, proxy);
+    }
+  }
+
   public static void main(String[] args) throws Exception {
     new TestRPC().testCallsInternal(conf);
   }