Przeglądaj źródła

HADOOP-19235. IPC client uses CompletableFuture to support asynchronous operations. (#6888)

Jian Zhang 9 miesięcy temu
rodzic
commit
f10ef7d70a

+ 82 - 74
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java

@@ -96,8 +96,8 @@ public class Client implements AutoCloseable {
   private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>();
   private static final ThreadLocal<Object> EXTERNAL_CALL_HANDLER
       = new ThreadLocal<>();
-  private static final ThreadLocal<AsyncGet<? extends Writable, IOException>>
-      ASYNC_RPC_RESPONSE = new ThreadLocal<>();
+  private static final ThreadLocal<CompletableFuture<Writable>> ASYNC_RPC_RESPONSE
+      = new ThreadLocal<>();
   private static final ThreadLocal<Boolean> asynchronousMode =
       new ThreadLocal<Boolean>() {
         @Override
@@ -110,7 +110,46 @@ public class Client implements AutoCloseable {
   @Unstable
   public static <T extends Writable> AsyncGet<T, IOException>
       getAsyncRpcResponse() {
-    return (AsyncGet<T, IOException>) ASYNC_RPC_RESPONSE.get();
+    CompletableFuture<Writable> responseFuture = ASYNC_RPC_RESPONSE.get();
+    return new AsyncGet<T, IOException>() {
+      @Override
+      public T get(long timeout, TimeUnit unit)
+          throws IOException, TimeoutException, InterruptedException {
+        try {
+          if (unit == null || timeout < 0) {
+            return (T) responseFuture.get();
+          }
+          return (T) responseFuture.get(timeout, unit);
+        } catch (ExecutionException e) {
+          Throwable cause = e.getCause();
+          if (cause instanceof IOException) {
+            throw (IOException) cause;
+          }
+          throw new IllegalStateException(e);
+        }
+      }
+
+      @Override
+      public boolean isDone() {
+        return responseFuture.isDone();
+      }
+    };
+  }
+
+  /**
+   * Retrieves the current response future from the thread-local storage.
+   *
+   * @return A {@link CompletableFuture} of type T that represents the
+   *         asynchronous operation. If no response future is present in
+   *         the thread-local storage, this method returns {@code null}.
+   * @param <T> The type of the value completed by the returned
+   *            {@link CompletableFuture}. It must be a subclass of
+   *            {@link Writable}.
+   * @see CompletableFuture
+   * @see Writable
+   */
+  public static <T extends Writable> CompletableFuture<T> getResponseFuture() {
+    return (CompletableFuture<T>) ASYNC_RPC_RESPONSE.get();
   }
 
   /**
@@ -277,10 +316,8 @@ public class Client implements AutoCloseable {
     final int id;               // call id
     final int retry;           // retry count
     final Writable rpcRequest;  // the serialized rpc request
-    Writable rpcResponse;       // null if rpc has error
-    IOException error;          // exception, null if success
+    private final CompletableFuture<Writable> rpcResponseFuture;
     final RPC.RpcKind rpcKind;      // Rpc EngineKind
-    boolean done;               // true when call is done
     private final Object externalHandler;
     private AlignmentContext alignmentContext;
 
@@ -304,6 +341,7 @@ public class Client implements AutoCloseable {
       }
 
       this.externalHandler = EXTERNAL_CALL_HANDLER.get();
+      this.rpcResponseFuture = new CompletableFuture<>();
     }
 
     @Override
@@ -314,9 +352,6 @@ public class Client implements AutoCloseable {
     /** Indicate when the call is complete and the
      * value or error are available.  Notifies by default.  */
     protected synchronized void callComplete() {
-      this.done = true;
-      notify();                                 // notify caller
-
       if (externalHandler != null) {
         synchronized (externalHandler) {
           externalHandler.notify();
@@ -339,7 +374,7 @@ public class Client implements AutoCloseable {
      * @param error exception thrown by the call; either local or remote
      */
     public synchronized void setException(IOException error) {
-      this.error = error;
+      rpcResponseFuture.completeExceptionally(error);
       callComplete();
     }
     
@@ -349,13 +384,9 @@ public class Client implements AutoCloseable {
      * @param rpcResponse return value of the rpc call.
      */
     public synchronized void setRpcResponse(Writable rpcResponse) {
-      this.rpcResponse = rpcResponse;
+      rpcResponseFuture.complete(rpcResponse);
       callComplete();
     }
-    
-    public synchronized Writable getRpcResponse() {
-      return rpcResponse;
-    }
   }
 
   /** Thread that reads responses and notifies callers.  Each connection owns a
@@ -1495,39 +1526,19 @@ public class Client implements AutoCloseable {
     }
 
     if (isAsynchronousMode()) {
-      final AsyncGet<Writable, IOException> asyncGet
-          = new AsyncGet<Writable, IOException>() {
-        @Override
-        public Writable get(long timeout, TimeUnit unit)
-            throws IOException, TimeoutException{
-          boolean done = true;
-          try {
-            final Writable w = getRpcResponse(call, connection, timeout, unit);
-            if (w == null) {
-              done = false;
-              throw new TimeoutException(call + " timed out "
-                  + timeout + " " + unit);
-            }
-            return w;
-          } finally {
-            if (done) {
-              releaseAsyncCall();
+      CompletableFuture<Writable> result = call.rpcResponseFuture.handle(
+          (rpcResponse, e) -> {
+            releaseAsyncCall();
+            if (e != null) {
+              IOException ioe = (IOException) e;
+              throw new CompletionException(warpIOException(ioe, connection));
             }
-          }
-        }
-
-        @Override
-        public boolean isDone() {
-          synchronized (call) {
-            return call.done;
-          }
-        }
-      };
-
-      ASYNC_RPC_RESPONSE.set(asyncGet);
+            return rpcResponse;
+          });
+      ASYNC_RPC_RESPONSE.set(result);
       return null;
     } else {
-      return getRpcResponse(call, connection, -1, null);
+      return getRpcResponse(call, connection);
     }
   }
 
@@ -1564,37 +1575,34 @@ public class Client implements AutoCloseable {
   }
 
   /** @return the rpc response or, in case of timeout, null. */
-  private Writable getRpcResponse(final Call call, final Connection connection,
-      final long timeout, final TimeUnit unit) throws IOException {
-    synchronized (call) {
-      while (!call.done) {
-        try {
-          AsyncGet.Util.wait(call, timeout, unit);
-          if (timeout >= 0 && !call.done) {
-            return null;
-          }
-        } catch (InterruptedException ie) {
-          Thread.currentThread().interrupt();
-          throw new InterruptedIOException("Call interrupted");
-        }
+  private Writable getRpcResponse(final Call call, final Connection connection)
+      throws IOException {
+    try {
+      return call.rpcResponseFuture.get();
+    } catch (InterruptedException ie) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedIOException("Call interrupted");
+    } catch (ExecutionException e) {
+      Throwable cause = e.getCause();
+      if (cause instanceof IOException) {
+        throw warpIOException((IOException) cause, connection);
       }
+      throw new IllegalStateException(e);
+    }
+  }
 
-      if (call.error != null) {
-        if (call.error instanceof RemoteException ||
-            call.error instanceof SaslException) {
-          call.error.fillInStackTrace();
-          throw call.error;
-        } else { // local exception
-          InetSocketAddress address = connection.getRemoteAddress();
-          throw NetUtils.wrapException(address.getHostName(),
-                  address.getPort(),
-                  NetUtils.getHostname(),
-                  0,
-                  call.error);
-        }
-      } else {
-        return call.getRpcResponse();
-      }
+  private IOException warpIOException(IOException ioe, Connection connection) {
+    if (ioe instanceof RemoteException ||
+        ioe instanceof SaslException) {
+      ioe.fillInStackTrace();
+      return ioe;
+    } else { // local exception
+      InetSocketAddress address = connection.getRemoteAddress();
+      return NetUtils.wrapException(address.getHostName(),
+          address.getPort(),
+          NetUtils.getHostname(),
+          0,
+          ioe);
     }
   }
 

+ 92 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java

@@ -28,6 +28,7 @@ import org.apache.hadoop.ipc.TestIPC.TestServer;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.util.Time;
 import org.apache.hadoop.util.concurrent.AsyncGetFuture;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -38,11 +39,14 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.util.*;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
+import static org.junit.jupiter.api.Assertions.fail;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 
@@ -137,6 +141,60 @@ public class TestAsyncIPC {
     }
   }
 
+  /**
+   * For testing the asynchronous calls of the RPC client
+   * implemented with CompletableFuture.
+   */
+  static class AsyncCompletableFutureCaller extends Thread {
+    private final Client client;
+    private final InetSocketAddress server;
+    private final int count;
+    private final List<CompletableFuture<Writable>> completableFutures;
+    private final List<Long> expectedValues;
+
+    AsyncCompletableFutureCaller(Client client, InetSocketAddress server, int count) {
+      this.client = client;
+      this.server = server;
+      this.count = count;
+      this.completableFutures = new ArrayList<>(count);
+      this.expectedValues = new ArrayList<>(count);
+      setName("Async CompletableFuture Caller");
+    }
+
+    @Override
+    public void run() {
+      // Set the RPC client to use asynchronous mode.
+      Client.setAsynchronousMode(true);
+      long startTime = Time.monotonicNow();
+      try {
+        for (int i = 0; i < count; i++) {
+          final long param = TestIPC.RANDOM.nextLong();
+          TestIPC.call(client, param, server, conf);
+          expectedValues.add(param);
+          completableFutures.add(Client.getResponseFuture());
+        }
+        // Since the run method is asynchronous,
+        // it does not need to wait for a response after sending a request,
+        // so the time taken by the run method is less than count * 100
+        // (where 100 is the time taken by the server to process a request).
+        long cost = Time.monotonicNow() - startTime;
+        assertTrue(cost < count * 100L);
+        LOG.info("[{}] run cost {}ms", Thread.currentThread().getName(), cost);
+      } catch (Exception e) {
+        fail();
+      }
+    }
+
+    public void assertReturnValues()
+        throws InterruptedException, ExecutionException {
+      for (int i = 0; i < count; i++) {
+        LongWritable value = (LongWritable) completableFutures.get(i).get();
+        assertEquals(expectedValues.get(i).longValue(), value.get(),
+            "call" + i + " failed.");
+      }
+    }
+  }
+
   static class AsyncLimitlCaller extends Thread {
     private Client client;
     private InetSocketAddress server;
@@ -547,4 +605,38 @@ public class TestAsyncIPC {
       assertEquals(startID + i, callIds.get(i).intValue());
     }
   }
+
+  @Test
+  @Timeout(value = 60)
+  public void testAsyncCallWithCompletableFuture() throws IOException,
+      InterruptedException, ExecutionException {
+    // Override client to store the call id
+    final Client client = new Client(LongWritable.class, conf);
+
+    // Construct an RPC server, which includes a handler thread.
+    final TestServer server = new TestIPC.TestServer(1, false, conf);
+    server.callListener = () -> {
+      try {
+        // The server requires at least 100 milliseconds to process a request.
+        Thread.sleep(100);
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    };
+
+    try {
+      InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      server.start();
+      // Send 10 asynchronous requests.
+      final AsyncCompletableFutureCaller caller =
+          new AsyncCompletableFutureCaller(client, addr, 10);
+      caller.start();
+      caller.join();
+      // Check if the values returned by the asynchronous call meet the expected values.
+      caller.assertReturnValues();
+    } finally {
+      client.stop();
+      server.stop();
+    }
+  }
 }