Explorar o código

HADOOP-11552. Allow handoff on the server side for RPC requests. Contributed by Siddharth Seth

(cherry picked from commit 3d94da1e00fc6238fad458e415219f87920f1fc3)
Jian He hai 8 anos
pai
achega
accd9136e4

+ 71 - 4
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java

@@ -31,7 +31,6 @@ import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.retry.RetryPolicy;
 import org.apache.hadoop.ipc.Client.ConnectionId;
 import org.apache.hadoop.ipc.RPC.RpcInvoker;
-import org.apache.hadoop.ipc.RpcWritable;
 import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.token.SecretManager;
@@ -344,6 +343,60 @@ public class ProtobufRpcEngine implements RpcEngine {
   }
   
   public static class Server extends RPC.Server {
+
+    static final ThreadLocal<ProtobufRpcEngineCallback> currentCallback =
+        new ThreadLocal<>();
+
+    static final ThreadLocal<CallInfo> currentCallInfo = new ThreadLocal<>();
+
+    /**
+     * Holder for the server and the method name of an RPC invocation. An
+     * instance is published through {@code currentCallInfo} while the
+     * invocation is in progress.
+     */
+    static class CallInfo {
+      /** Server that is handling the invocation. */
+      private final RPC.Server server;
+      /** Name of the method being invoked. */
+      private final String methodName;
+
+      /**
+       * @param server     the server handling the invocation
+       * @param methodName the name of the invoked method
+       */
+      public CallInfo(RPC.Server server, String methodName) {
+        this.methodName = methodName;
+        this.server = server;
+      }
+    }
+
+    /**
+     * Callback used by service code to complete an RPC after the handler
+     * method has returned. On construction it captures the server, the
+     * current call and the method name from the calling thread; on completion
+     * it reports the time elapsed since construction to the deferred metrics.
+     */
+    static class ProtobufRpcEngineCallbackImpl
+        implements ProtobufRpcEngineCallback {
+
+      private final RPC.Server server;
+      private final Call call;
+      private final String methodName;
+      private final long setupTime;
+
+      /**
+       * Captures the state of the call being handled by the calling thread.
+       * Requires {@code currentCallInfo} and the current call to be set for
+       * this thread.
+       */
+      public ProtobufRpcEngineCallbackImpl() {
+        // Read the thread-local once so that both fields are guaranteed to
+        // come from the same CallInfo instance.
+        final CallInfo callInfo = currentCallInfo.get();
+        this.server = callInfo.server;
+        this.call = Server.getCurCall().get();
+        this.methodName = callInfo.methodName;
+        this.setupTime = Time.now();
+      }
+
+      /**
+       * Completes the call successfully with the given message and records
+       * the processing time under the method name.
+       *
+       * @param message the response message
+       */
+      @Override
+      public void setResponse(Message message) {
+        long processingTime = Time.now() - setupTime;
+        call.setDeferredResponse(RpcWritable.wrap(message));
+        server.updateDeferredMetrics(methodName, processingTime);
+      }
+
+      /**
+       * Completes the call with an error and records the processing time
+       * under the simple class name of the throwable.
+       *
+       * @param t the failure; {@code null} is replaced by an
+       *          {@link java.io.IOException}
+       */
+      @Override
+      public void error(Throwable t) {
+        long processingTime = Time.now() - setupTime;
+        // Service code may report a failure without supplying a throwable.
+        // Substitute one here so that the metric name lookup below cannot
+        // fail with a NullPointerException before the error reaches the call.
+        if (t == null) {
+          t = new java.io.IOException(
+              "User code indicated an error without an exception");
+        }
+        String detailedMetricsName = t.getClass().getSimpleName();
+        server.updateDeferredMetrics(detailedMetricsName, processingTime);
+        call.setDeferredError(t);
+      }
+    }
+
+    /**
+     * Creates a callback for the call being handled by the calling thread and
+     * records it in {@code currentCallback}.
+     *
+     * @return the callback to complete the call with
+     */
+    @InterfaceStability.Unstable
+    public static ProtobufRpcEngineCallback registerForDeferredResponse() {
+      final ProtobufRpcEngineCallback deferredCallback =
+          new ProtobufRpcEngineCallbackImpl();
+      currentCallback.set(deferredCallback);
+      return deferredCallback;
+    }
+
     /**
      * Construct an RPC server.
      * 
@@ -442,9 +495,19 @@ public class ProtobufRpcEngine implements RpcEngine {
         long startTime = Time.now();
         int qTime = (int) (startTime - receiveTime);
         Exception exception = null;
+        boolean isDeferred = false;
         try {
           server.rpcDetailedMetrics.init(protocolImpl.protocolClass);
+          currentCallInfo.set(new CallInfo(server, methodName));
           result = service.callBlockingMethod(methodDescriptor, null, param);
+          // Check if this needs to be a deferred response,
+          // by checking the ThreadLocal callback being set
+          if (currentCallback.get() != null) {
+            Server.getCurCall().get().deferResponse();
+            isDeferred = true;
+            currentCallback.set(null);
+            return null;
+          }
         } catch (ServiceException e) {
           exception = (Exception) e.getCause();
           throw (Exception) e.getCause();
@@ -452,10 +515,13 @@ public class ProtobufRpcEngine implements RpcEngine {
           exception = e;
           throw e;
         } finally {
+          currentCallInfo.set(null);
           int processingTime = (int) (Time.now() - startTime);
           if (LOG.isDebugEnabled()) {
-            String msg = "Served: " + methodName + " queueTime= " + qTime +
-                " procesingTime= " + processingTime;
+            String msg =
+                "Served: " + methodName + (isDeferred ? ", deferred" : "") +
+                    ", queueTime= " + qTime +
+                    " procesingTime= " + processingTime;
             if (exception != null) {
               msg += " exception= " + exception.getClass().getSimpleName();
             }
@@ -464,7 +530,8 @@ public class ProtobufRpcEngine implements RpcEngine {
           String detailedMetricsName = (exception == null) ?
               methodName :
               exception.getClass().getSimpleName();
-          server.updateMetrics(detailedMetricsName, qTime, processingTime);
+          server.updateMetrics(detailedMetricsName, qTime, processingTime,
+              isDeferred);
         }
         return RpcWritable.wrap(result);
       }

+ 29 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngineCallback.java

@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.ipc;
+
+import com.google.protobuf.Message;
+
+/**
+ * Callback through which a deferred protobuf RPC is completed, either with a
+ * response message or with an error.
+ */
+public interface ProtobufRpcEngineCallback {
+
+  /**
+   * Completes the call with a response.
+   *
+   * @param message the response message
+   */
+  void setResponse(Message message);
+
+  /**
+   * Completes the call with a failure.
+   *
+   * @param t the cause of the failure
+   */
+  void error(Throwable t);
+
+}

+ 168 - 36
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java

@@ -497,18 +497,25 @@ public abstract class Server {
     }
   }
 
-  void updateMetrics(String name, int queueTime, int processingTime) {
+  void updateMetrics(String name, int queueTime, int processingTime,
+                     boolean deferredCall) {
     rpcMetrics.addRpcQueueTime(queueTime);
-    rpcMetrics.addRpcProcessingTime(processingTime);
-    rpcDetailedMetrics.addProcessingTime(name, processingTime);
-    callQueue.addResponseTime(name, getPriorityLevel(), queueTime,
-        processingTime);
-
-    if (isLogSlowRPC()) {
-      logSlowRpcCalls(name, processingTime);
+    if (!deferredCall) {
+      rpcMetrics.addRpcProcessingTime(processingTime);
+      rpcDetailedMetrics.addProcessingTime(name, processingTime);
+      callQueue.addResponseTime(name, getPriorityLevel(), queueTime,
+          processingTime);
+      if (isLogSlowRPC()) {
+        logSlowRpcCalls(name, processingTime);
+      }
     }
   }
 
+  /**
+   * Records the processing time of a deferred call in both the aggregate and
+   * the detailed RPC metrics.
+   *
+   * @param name           the name the detailed metric is recorded under
+   * @param processingTime the processing time to record
+   */
+  void updateDeferredMetrics(String name, long processingTime) {
+    this.rpcMetrics.addDeferredRpcProcessingTime(processingTime);
+    this.rpcDetailedMetrics.addDeferredProcessingTime(name, processingTime);
+  }
+
   /**
    * A convenience method to bind to a given address and report 
    * better exceptions if the address is not a valid host.
@@ -674,6 +681,7 @@ public abstract class Server {
     final byte[] clientId;
     private final TraceScope traceScope; // the HTrace scope on the server side
     private final CallerContext callerContext; // the call context
+    private boolean deferredResponse = false;
     private int priorityLevel;
     // the priority level assigned by scheduler, 0 by default
 
@@ -782,6 +790,22 @@ public abstract class Server {
     public void setPriorityLevel(int priorityLevel) {
       this.priorityLevel = priorityLevel;
     }
+
+    /**
+     * Marks the response of this call as deferred.
+     */
+    @InterfaceStability.Unstable
+    public void deferResponse() {
+      deferredResponse = true;
+    }
+
+    /**
+     * @return {@code true} if {@link #deferResponse()} has been called on
+     *         this call
+     */
+    @InterfaceStability.Unstable
+    public boolean isResponseDeferred() {
+      return deferredResponse;
+    }
+
+    /**
+     * Supplies the response for a call whose response was deferred. This
+     * base implementation does nothing.
+     *
+     * @param response the response value
+     */
+    public void setDeferredResponse(Writable response) {
+      // Intentionally empty.
+    }
+
+    /**
+     * Supplies the error for a call whose response was deferred. This base
+     * implementation does nothing.
+     *
+     * @param t the cause of the failure
+     */
+    public void setDeferredError(Throwable t) {
+      // Intentionally empty.
+    }
   }
 
   /** A RPC extended call queued for handling. */
@@ -835,43 +859,58 @@ public abstract class Server {
         Server.LOG.info(Thread.currentThread().getName() + ": skipped " + this);
         return null;
       }
-      String errorClass = null;
-      String error = null;
-      RpcStatusProto returnStatus = RpcStatusProto.SUCCESS;
-      RpcErrorCodeProto detailedErr = null;
       Writable value = null;
+      ResponseParams responseParams = new ResponseParams();
 
       try {
         value = call(
             rpcKind, connection.protocolName, rpcRequest, timestamp);
       } catch (Throwable e) {
-        if (e instanceof UndeclaredThrowableException) {
-          e = e.getCause();
-        }
-        logException(Server.LOG, e, this);
-        if (e instanceof RpcServerException) {
-          RpcServerException rse = ((RpcServerException)e);
-          returnStatus = rse.getRpcStatusProto();
-          detailedErr = rse.getRpcErrorCodeProto();
-        } else {
-          returnStatus = RpcStatusProto.ERROR;
-          detailedErr = RpcErrorCodeProto.ERROR_APPLICATION;
-        }
-        errorClass = e.getClass().getName();
-        error = StringUtils.stringifyException(e);
-        // Remove redundant error class name from the beginning of the
-        // stack trace
-        String exceptionHdr = errorClass + ": ";
-        if (error.startsWith(exceptionHdr)) {
-          error = error.substring(exceptionHdr.length());
+        populateResponseParamsOnError(e, responseParams);
+      }
+      if (!isResponseDeferred()) {
+        setupResponse(this, responseParams.returnStatus,
+            responseParams.detailedErr,
+            value, responseParams.errorClass, responseParams.error);
+        sendResponse();
+      } else {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Deferring response for callId: " + this.callId);
         }
       }
-      setupResponse(this, returnStatus, detailedErr,
-          value, errorClass, error);
-      sendResponse();
       return null;
     }
 
+    /**
+     * Logs the given throwable and fills in the status, error code, error
+     * class and error text of {@code responseParams} from it. An
+     * {@link UndeclaredThrowableException} is replaced by its cause when it
+     * has one.
+     *
+     * @param t              the {@link java.lang.Throwable} to use to set
+     *                       errorInfo
+     * @param responseParams the {@link ResponseParams} instance to populate
+     */
+    private void populateResponseParamsOnError(Throwable t,
+                                               ResponseParams responseParams) {
+      // Only unwrap when there is a cause; otherwise t would become null and
+      // the calls below would fail with a NullPointerException.
+      if (t instanceof UndeclaredThrowableException
+          && t.getCause() != null) {
+        t = t.getCause();
+      }
+      logException(Server.LOG, t, this);
+      if (t instanceof RpcServerException) {
+        RpcServerException rse = ((RpcServerException) t);
+        responseParams.returnStatus = rse.getRpcStatusProto();
+        responseParams.detailedErr = rse.getRpcErrorCodeProto();
+      } else {
+        responseParams.returnStatus = RpcStatusProto.ERROR;
+        responseParams.detailedErr = RpcErrorCodeProto.ERROR_APPLICATION;
+      }
+      responseParams.errorClass = t.getClass().getName();
+      responseParams.error = StringUtils.stringifyException(t);
+      // Remove redundant error class name from the beginning of the
+      // stack trace
+      String exceptionHdr = responseParams.errorClass + ": ";
+      if (responseParams.error.startsWith(exceptionHdr)) {
+        responseParams.error =
+            responseParams.error.substring(exceptionHdr.length());
+      }
+    }
+
     void setResponse(ByteBuffer response) throws IOException {
       this.rpcResponse = response;
     }
@@ -891,6 +930,80 @@ public abstract class Server {
       connection.sendResponse(call);
     }
 
+    /**
+     * Send a deferred response. Failures are logged, including their cause,
+     * and are not propagated to the caller.
+     */
+    private void sendDeferedResponse() {
+      try {
+        connection.sendResponse(this);
+      } catch (Exception e) {
+        // For synchronous calls, application code is done once it's returned
+        // from a method. It does not expect to receive an error.
+        // This is equivalent to what happens in synchronous calls when the
+        // Responder is not able to send out the response.
+        // The exception is passed to the logger so that the reason for the
+        // failure is not lost.
+        LOG.error("Failed to send deferred response. ThreadName=" + Thread
+            .currentThread().getName() + ", CallId="
+            + callId + ", hostname=" + getHostAddress(), e);
+      }
+    }
+
+    /**
+     * Sets up a successful response from the given value and sends it, as
+     * long as the server is still running. A failure to set up the response
+     * is logged and nothing is sent.
+     *
+     * @param response the return value of the call
+     */
+    @Override
+    public void setDeferredResponse(Writable response) {
+      if (this.connection.getServer().running) {
+        try {
+          setupResponse(this, RpcStatusProto.SUCCESS, null, response,
+              null, null);
+        } catch (IOException e) {
+          // For synchronous calls, application code is done once it has
+          // returned from a method. It does not expect to receive an error.
+          // This is equivalent to what happens in synchronous calls when the
+          // response cannot be sent.
+          // The exception is passed to the logger so that the reason for the
+          // failure is not lost.
+          LOG.error(
+              "Failed to setup deferred successful response. ThreadName=" +
+                  Thread.currentThread().getName() + ", Call=" + this, e);
+          return;
+        }
+        sendDeferedResponse();
+      }
+    }
+
+    /**
+     * Sets up an error response from the given throwable and sends it, as
+     * long as the server is still running. A failure to set up the response
+     * is logged and nothing is sent.
+     *
+     * @param t the cause of the failure; {@code null} is replaced by an
+     *          {@link IOException}
+     */
+    @Override
+    public void setDeferredError(Throwable t) {
+      if (this.connection.getServer().running) {
+        if (t == null) {
+          t = new IOException(
+              "User code indicated an error without an exception");
+        }
+        try {
+          ResponseParams responseParams = new ResponseParams();
+          populateResponseParamsOnError(t, responseParams);
+          setupResponse(this, responseParams.returnStatus,
+              responseParams.detailedErr,
+              null, responseParams.errorClass, responseParams.error);
+        } catch (IOException e) {
+          // For synchronous calls, application code is done once it has
+          // returned from a method. It does not expect to receive an error.
+          // This is equivalent to what happens in synchronous calls when the
+          // response cannot be sent.
+          // The exception is passed to the logger so that the reason for the
+          // failure is not lost.
+          LOG.error(
+              "Failed to setup deferred error response. ThreadName=" +
+                  Thread.currentThread().getName() + ", Call=" + this, e);
+          // No response could be set up, so there is nothing to send.
+          return;
+        }
+        sendDeferedResponse();
+      }
+    }
+
+    /**
+     * Holds response parameters. Defaults set to work for successful
+     * invocations
+     */
+    private class ResponseParams {
+      String errorClass = null;
+      String error = null;
+      RpcErrorCodeProto detailedErr = null;
+      RpcStatusProto returnStatus = RpcStatusProto.SUCCESS;
+    }
+
     @Override
     public String toString() {
       return super.toString() + " " + rpcRequest + " from " + connection;
@@ -1589,6 +1702,10 @@ public abstract class Server {
       return lastContact;
     }
 
+    /**
+     * @return the enclosing {@link Server} instance
+     */
+    public Server getServer() {
+      final Server enclosingServer = Server.this;
+      return enclosingServer;
+    }
+
     /* Return true if the connection has no outstanding rpc */
     private boolean isIdle() {
       return rpcCount.get() == 0;
@@ -2133,8 +2250,21 @@ public abstract class Server {
     }
     
     /**
      * Process an RPC Request - handle connection setup and decoding of
      * request into a Call
+     *
+     * In detail, one RPC Request is processed from the buffer read from the
+     * socket stream:
+     *  - decode rpc in a rpc-Call
+     *  - handle out-of-band RPC requests such as the initial connectionContext
+     *  - A successfully decoded RpcCall will be deposited in RPC-Q and
+     *    its response will be sent later when the request is processed.
+     *
+     * Prior to this call the connectionHeader ("hrpc...") has been handled and
+     * if SASL then SASL has been established and the buf we are passed
+     * has been unwrapped from SASL.
+     *
      * @param bb - contains the RPC request header and the rpc request
      * @throws IOException - internal error that should not be returned to
      *         client, typically failure to respond to client
@@ -2355,8 +2485,11 @@ public abstract class Server {
     
     /**
      * Decode the a protobuf from the given input stream 
      * @param message - Representation of the type of message
      * @param buffer - a buffer to read the protobuf
      * @return Message - decoded protobuf
      * @throws WrappedRpcServerException - deserialization failed
      */
@@ -2660,11 +2793,10 @@ public abstract class Server {
   private void closeConnection(Connection connection) {
     connectionManager.close(connection);
   }
-  
+
   /**
    * Setup response for the IPC Call.
    * 
-   * @param responseBuf buffer to serialize the response into
    * @param call {@link Call} to which we are setting up the response
    * @param status of the IPC call
    * @param rv return value for the IPC Call, if the call was successful

+ 2 - 1
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java

@@ -549,7 +549,8 @@ public class WritableRpcEngine implements RpcEngine {
           String detailedMetricsName = (exception == null) ?
               call.getMethodName() :
               exception.getClass().getSimpleName();
-          server.updateMetrics(detailedMetricsName, qTime, processingTime);
+          server
+              .updateMetrics(detailedMetricsName, qTime, processingTime, false);
         }
       }
     }

+ 6 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/metrics/RpcDetailedMetrics.java

@@ -35,6 +35,7 @@ import org.apache.hadoop.metrics2.lib.MutableRatesWithAggregation;
 public class RpcDetailedMetrics {
 
   @Metric MutableRatesWithAggregation rates;
+  @Metric MutableRatesWithAggregation deferredRpcRates;
 
   static final Log LOG = LogFactory.getLog(RpcDetailedMetrics.class);
   final MetricsRegistry registry;
@@ -60,6 +61,7 @@ public class RpcDetailedMetrics {
    */
   public void init(Class<?> protocol) {
     rates.init(protocol);
+    deferredRpcRates.init(protocol);
   }
 
   /**
@@ -72,6 +74,10 @@ public class RpcDetailedMetrics {
     rates.add(name, processingTime);
   }
 
+  /**
+   * Adds a processing time sample to the deferred RPC rates.
+   *
+   * @param name           the name the sample is recorded under
+   * @param processingTime the processing time to record
+   */
+  public void addDeferredProcessingTime(String name, long processingTime) {
+    this.deferredRpcRates.add(name, processingTime);
+  }
+
   /**
    * Shutdown the instrumentation for the process
    */

+ 33 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/metrics/RpcMetrics.java

@@ -61,6 +61,8 @@ public class RpcMetrics {
           new MutableQuantiles[intervals.length];
       rpcProcessingTimeMillisQuantiles =
           new MutableQuantiles[intervals.length];
+      deferredRpcProcessingTimeMillisQuantiles =
+          new MutableQuantiles[intervals.length];
       for (int i = 0; i < intervals.length; i++) {
         int interval = intervals[i];
         rpcQueueTimeMillisQuantiles[i] = registry.newQuantiles("rpcQueueTime"
@@ -69,6 +71,10 @@ public class RpcMetrics {
         rpcProcessingTimeMillisQuantiles[i] = registry.newQuantiles(
             "rpcProcessingTime" + interval + "s",
             "rpc processing time in milli second", "ops", "latency", interval);
+        deferredRpcProcessingTimeMillisQuantiles[i] = registry
+            .newQuantiles("deferredRpcProcessingTime" + interval + "s",
+                "deferred rpc processing time in milli seconds", "ops",
+                "latency", interval);
       }
     }
     LOG.debug("Initialized " + registry);
@@ -87,6 +93,8 @@ public class RpcMetrics {
   MutableQuantiles[] rpcQueueTimeMillisQuantiles;
   @Metric("Processing time") MutableRate rpcProcessingTime;
   MutableQuantiles[] rpcProcessingTimeMillisQuantiles;
+  @Metric("Deferred Processing time") MutableRate deferredRpcProcessingTime;
+  MutableQuantiles[] deferredRpcProcessingTimeMillisQuantiles;
   @Metric("Number of authentication failures")
   MutableCounterLong rpcAuthenticationFailures;
   @Metric("Number of authentication successes")
@@ -202,6 +210,15 @@ public class RpcMetrics {
     }
   }
 
+  /**
+   * Adds a deferred RPC processing time sample to the rate metric and, when
+   * quantiles are enabled, to every deferred processing time quantile.
+   *
+   * @param processingTime the processing time to record
+   */
+  public void addDeferredRpcProcessingTime(long processingTime) {
+    deferredRpcProcessingTime.add(processingTime);
+    if (!rpcQuantileEnable) {
+      return;
+    }
+    final MutableQuantiles[] quantiles =
+        deferredRpcProcessingTimeMillisQuantiles;
+    for (int i = 0; i < quantiles.length; i++) {
+      quantiles[i].add(processingTime);
+    }
+  }
+
   /**
    * One client backoff event
    */
@@ -255,4 +272,20 @@ public class RpcMetrics {
   public long getRpcSlowCalls() {
     return rpcSlowCalls.value();
   }
+
+  /**
+   * @return the rate metric that deferred RPC processing times are added to
+   */
+  public MutableRate getDeferredRpcProcessingTime() {
+    return this.deferredRpcProcessingTime;
+  }
+
+  /**
+   * @return the number of samples in the last statistics of the deferred RPC
+   *         processing time metric
+   */
+  public long getDeferredRpcProcessingSampleCount() {
+    final long sampleCount =
+        this.deferredRpcProcessingTime.lastStat().numSamples();
+    return sampleCount;
+  }
+
+  /**
+   * @return the mean of the last statistics of the deferred RPC processing
+   *         time metric
+   */
+  public double getDeferredRpcProcessingMean() {
+    final double mean = this.deferredRpcProcessingTime.lastStat().mean();
+    return mean;
+  }
+
+  /**
+   * @return the standard deviation of the last statistics of the deferred
+   *         RPC processing time metric
+   */
+  public double getDeferredRpcProcessingStdDev() {
+    final double stdDev = this.deferredRpcProcessingTime.lastStat().stddev();
+    return stdDev;
+  }
 }

+ 167 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestProtoBufRpcServerHandoff.java

@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.ipc;
+
+import java.net.InetSocketAddress;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletionService;
+import java.util.concurrent.ExecutorCompletionService;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+import com.google.protobuf.BlockingService;
+import com.google.protobuf.RpcController;
+import com.google.protobuf.ServiceException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.ipc.protobuf.TestProtos;
+import org.apache.hadoop.ipc.protobuf.TestRpcServiceProtos.TestProtobufRpcHandoffProto;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Verifies that a protobuf RPC service can hand a call off to another thread
+ * and answer it later: with a single handler, two calls that each sleep for
+ * five seconds must finish at about the same time.
+ */
+public class TestProtoBufRpcServerHandoff {
+
+  public static final Log LOG =
+      LogFactory.getLog(TestProtoBufRpcServerHandoff.class);
+
+  @Test(timeout = 20000)
+  public void test() throws Exception {
+    Configuration conf = new Configuration();
+
+    TestProtoBufRpcServerHandoffServer serverImpl =
+        new TestProtoBufRpcServerHandoffServer();
+    BlockingService blockingService =
+        TestProtobufRpcHandoffProto.newReflectiveBlockingService(serverImpl);
+
+    RPC.setProtocolEngine(conf, TestProtoBufRpcServerHandoffProtocol.class,
+        ProtobufRpcEngine.class);
+    RPC.Server server = new RPC.Builder(conf)
+        .setProtocol(TestProtoBufRpcServerHandoffProtocol.class)
+        .setInstance(blockingService)
+        .setVerbose(true)
+        .setNumHandlers(1) // Num Handlers explicitly set to 1 for test.
+        .build();
+    server.start();
+
+    ExecutorService executorService = Executors.newFixedThreadPool(2);
+    TestProtoBufRpcServerHandoffProtocol client = null;
+    try {
+      InetSocketAddress address = server.getListenerAddress();
+      long serverStartTime = System.currentTimeMillis();
+      LOG.info(
+          "Server started at: " + address + " at time: " + serverStartTime);
+
+      client = RPC.getProxy(
+          TestProtoBufRpcServerHandoffProtocol.class, 1, address, conf);
+
+      CompletionService<ClientInvocationCallable> completionService =
+          new ExecutorCompletionService<ClientInvocationCallable>(
+              executorService);
+
+      completionService.submit(new ClientInvocationCallable(client, 5000L));
+      completionService.submit(new ClientInvocationCallable(client, 5000L));
+
+      long submitTime = System.currentTimeMillis();
+      Future<ClientInvocationCallable> future1 = completionService.take();
+      Future<ClientInvocationCallable> future2 = completionService.take();
+
+      ClientInvocationCallable callable1 = future1.get();
+      ClientInvocationCallable callable2 = future2.get();
+
+      LOG.info(callable1);
+      LOG.info(callable2);
+
+      // Ensure the 5 second sleep responses are within a reasonable time of
+      // each other.
+      Assert.assertTrue(
+          Math.abs(callable1.endTime - callable2.endTime) < 2000L);
+      Assert.assertTrue(System.currentTimeMillis() - submitTime < 7000L);
+    } finally {
+      // Release the threads, the proxy and the server even when the test
+      // fails, so that nothing leaks into other tests.
+      executorService.shutdownNow();
+      if (client != null) {
+        RPC.stopProxy(client);
+      }
+      server.stop();
+    }
+  }
+
+  /**
+   * Invokes the sleep RPC once and records when the invocation started and
+   * ended, together with the result.
+   */
+  private static class ClientInvocationCallable
+      implements Callable<ClientInvocationCallable> {
+    final TestProtoBufRpcServerHandoffProtocol client;
+    final long sleepTime;
+    TestProtos.SleepResponseProto2 result;
+    long startTime;
+    long endTime;
+
+
+    private ClientInvocationCallable(
+        TestProtoBufRpcServerHandoffProtocol client, long sleepTime) {
+      this.client = client;
+      this.sleepTime = sleepTime;
+    }
+
+    @Override
+    public ClientInvocationCallable call() throws Exception {
+      startTime = System.currentTimeMillis();
+      result = client.sleep(null,
+          TestProtos.SleepRequestProto2.newBuilder().setSleepTime(sleepTime)
+              .build());
+      endTime = System.currentTimeMillis();
+      return this;
+    }
+
+    @Override
+    public String toString() {
+      return "startTime=" + startTime + ", endTime=" + endTime +
+          (result != null ?
+              ", result.receiveTime=" + result.getReceiveTime() +
+                  ", result.responseTime=" +
+                  result.getResponseTime() : "");
+    }
+  }
+
+  @ProtocolInfo(
+      protocolName = "org.apache.hadoop.ipc.TestProtoBufRpcServerHandoff$TestProtoBufRpcServerHandoffProtocol",
+      protocolVersion = 1)
+  public interface TestProtoBufRpcServerHandoffProtocol
+      extends TestProtobufRpcHandoffProto.BlockingInterface {
+  }
+
+  /**
+   * Service implementation that registers for a deferred response and
+   * answers from a separate thread after sleeping for the requested time.
+   */
+  public static class TestProtoBufRpcServerHandoffServer
+      implements TestProtoBufRpcServerHandoffProtocol {
+
+    @Override
+    public TestProtos.SleepResponseProto2 sleep
+        (RpcController controller,
+         TestProtos.SleepRequestProto2 request) throws
+        ServiceException {
+      final long startTime = System.currentTimeMillis();
+      final ProtobufRpcEngineCallback callback =
+          ProtobufRpcEngine.Server.registerForDeferredResponse();
+      final long sleepTime = request.getSleepTime();
+      new Thread() {
+        @Override
+        public void run() {
+          try {
+            Thread.sleep(sleepTime);
+          } catch (InterruptedException e) {
+            // Restore the interrupt flag and fail the call, so that the
+            // client is not left waiting for a response that never comes.
+            Thread.currentThread().interrupt();
+            callback.error(e);
+            return;
+          }
+          callback.setResponse(
+              TestProtos.SleepResponseProto2.newBuilder()
+                  .setReceiveTime(startTime)
+                  .setResponseTime(System.currentTimeMillis()).build());
+        }
+      }.start();
+      // The real response is delivered through the callback.
+      return null;
+    }
+  }
+}

+ 218 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRpcServerHandoff.java

@@ -0,0 +1,218 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.ipc;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.net.NetUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Verifies deferred responses at the {@link Server} level: a call whose
+ * response is deferred gets no reply until the server side supplies a
+ * response or an error.
+ */
+public class TestRpcServerHandoff {
+
+  public static final Log LOG =
+      LogFactory.getLog(TestRpcServerHandoff.class);
+
+  private static final String BIND_ADDRESS = "0.0.0.0";
+  private static final Configuration conf = new Configuration();
+
+
+  /**
+   * Server that defers the response of every call and remembers the request
+   * and the call, so that the test can complete the call later.
+   */
+  public static class ServerForHandoffTest extends Server {
+
+    private final AtomicBoolean invoked = new AtomicBoolean(false);
+    private final ReentrantLock lock = new ReentrantLock();
+    private final Condition invokedCondition = lock.newCondition();
+
+    private volatile Writable request;
+    private volatile Call deferredCall;
+
+    protected ServerForHandoffTest(int handlerCount) throws IOException {
+      super(BIND_ADDRESS, 0, BytesWritable.class, handlerCount, conf);
+    }
+
+    @Override
+    public Writable call(RPC.RpcKind rpcKind, String protocol, Writable param,
+                         long receiveTime) throws Exception {
+      request = param;
+      deferredCall = Server.getCurCall().get();
+      Server.getCurCall().get().deferResponse();
+      lock.lock();
+      try {
+        invoked.set(true);
+        invokedCondition.signal();
+      } finally {
+        lock.unlock();
+      }
+      return null;
+    }
+
+    /** Blocks until {@code call} has been invoked. */
+    void awaitInvocation() throws InterruptedException {
+      lock.lock();
+      try {
+        while (!invoked.get()) {
+          invokedCondition.await();
+        }
+      } finally {
+        lock.unlock();
+      }
+    }
+
+    /** Completes the deferred call by echoing the request back. */
+    void sendResponse() {
+      deferredCall.setDeferredResponse(request);
+    }
+
+    /** Completes the deferred call with an error. */
+    void sendError() {
+      deferredCall.setDeferredError(new IOException("DeferredError"));
+    }
+  }
+
+  @Test(timeout = 10000)
+  public void testDeferredResponse() throws IOException, InterruptedException,
+      ExecutionException {
+
+
+    ServerForHandoffTest server = new ServerForHandoffTest(2);
+    server.start();
+    try {
+      InetSocketAddress serverAddress = NetUtils.getConnectAddress(server);
+      byte[] requestBytes = generateRandomBytes(1024);
+      ClientCallable clientCallable =
+          new ClientCallable(serverAddress, conf, requestBytes);
+
+      FutureTask<Writable> future = new FutureTask<Writable>(clientCallable);
+      Thread clientThread = new Thread(future);
+      clientThread.start();
+
+      server.awaitInvocation();
+      awaitResponseTimeout(future);
+
+      server.sendResponse();
+      BytesWritable response = (BytesWritable) future.get();
+
+      Assert.assertEquals(new BytesWritable(requestBytes), response);
+    } finally {
+      server.stop();
+    }
+  }
+
+  @Test(timeout = 10000)
+  public void testDeferredException() throws IOException, InterruptedException,
+      ExecutionException {
+    ServerForHandoffTest server = new ServerForHandoffTest(2);
+    server.start();
+    try {
+      InetSocketAddress serverAddress = NetUtils.getConnectAddress(server);
+      byte[] requestBytes = generateRandomBytes(1024);
+      ClientCallable clientCallable =
+          new ClientCallable(serverAddress, conf, requestBytes);
+
+      FutureTask<Writable> future = new FutureTask<Writable>(clientCallable);
+      Thread clientThread = new Thread(future);
+      clientThread.start();
+
+      server.awaitInvocation();
+      awaitResponseTimeout(future);
+
+      server.sendError();
+      try {
+        future.get();
+        Assert.fail("Call succeeded. Was expecting an exception");
+      } catch (ExecutionException e) {
+        Throwable cause = e.getCause();
+        Assert.assertTrue(cause instanceof RemoteException);
+        RemoteException re = (RemoteException) cause;
+        Assert.assertTrue(re.toString().contains("DeferredError"));
+      }
+    } finally {
+      server.stop();
+    }
+  }
+
+  /**
+   * Polls the future for about three seconds and fails the test if it
+   * completes during that time.
+   */
+  private void awaitResponseTimeout(FutureTask<Writable> future) throws
+      ExecutionException,
+      InterruptedException {
+    long sleepTime = 3000L;
+    while (sleepTime > 0) {
+      try {
+        future.get(200L, TimeUnit.MILLISECONDS);
+        Assert.fail("Expected to timeout since" +
+            " the deferred response hasn't been registered");
+      } catch (TimeoutException e) {
+        // Ignoring. Expected to time out.
+      }
+      sleepTime -= 200L;
+    }
+    LOG.info("Done sleeping");
+  }
+
+  /** Sends the request bytes to the server and returns the response. */
+  private static class ClientCallable implements Callable<Writable> {
+
+    private final InetSocketAddress address;
+    private final Configuration conf;
+    final byte[] requestBytes;
+
+
+    private ClientCallable(InetSocketAddress address, Configuration conf,
+                           byte[] requestBytes) {
+      this.address = address;
+      this.conf = conf;
+      this.requestBytes = requestBytes;
+    }
+
+    @Override
+    public Writable call() throws Exception {
+      Client client = new Client(BytesWritable.class, conf);
+      try {
+        Writable param = new BytesWritable(requestBytes);
+        final Client.ConnectionId remoteId =
+            Client.ConnectionId.getConnectionId(address, null,
+                null, 0, null, conf);
+        return client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId,
+            new AtomicBoolean(false));
+      } finally {
+        // Stop the client so that its connection does not outlive the call.
+        client.stop();
+      }
+    }
+  }
+
+  /** Returns {@code length} random bytes in the range 'a' to 'z'. */
+  private byte[] generateRandomBytes(int length) {
+    Random random = new Random();
+    byte[] bytes = new byte[length];
+    for (int i = 0; i < length; i++) {
+      bytes[i] = (byte) ('a' + random.nextInt(26));
+    }
+    return bytes;
+  }
+}

+ 10 - 1
hadoop-common-project/hadoop-common/src/test/proto/test.proto

@@ -90,4 +90,13 @@ message AuthMethodResponseProto {
 
 message AuthUserResponseProto {
   required string authUser = 1;
-}
+}
+
+// Request of the sleep RPC used by the server handoff test.
+message SleepRequestProto2 {
+  // The test server passes this value to Thread.sleep(), i.e. milliseconds.
+  optional int64 sleep_time = 1;
+}
+
+// Response of the sleep RPC used by the server handoff test.
+message SleepResponseProto2 {
+  // System.currentTimeMillis() taken when the test server entered the call.
+  optional int64 receive_time = 1;
+  // System.currentTimeMillis() taken when the test server built the response.
+  optional int64 response_time = 2;
+}

+ 4 - 0
hadoop-common-project/hadoop-common/src/test/proto/test_rpc_service.proto

@@ -65,3 +65,7 @@ service NewerProtobufRpcProto {
   rpc ping(EmptyRequestProto) returns (EmptyResponseProto);
   rpc echo(EmptyRequestProto) returns (EmptyResponseProto);
 }
+
+// Service used by the server handoff test.
+service TestProtobufRpcHandoffProto {
+  // The test implementation answers after sleeping for the requested time.
+  rpc sleep(SleepRequestProto2) returns (SleepResponseProto2);
+}