Quellcode durchsuchen

HADOOP-12483. Maintain wrapped SASL ordering for postponed IPC responses. (Daryn Sharp via yliu)

yliu vor 9 Jahren
Ursprung
Commit
476a251e5e

+ 3 - 0
hadoop-common-project/hadoop-common/CHANGES.txt

@@ -1243,6 +1243,9 @@ Release 2.8.0 - UNRELEASED
     HADOOP-10941. Proxy user verification NPEs if remote host is unresolvable.
     (Benoy Antony via stevel).
 
+    HADOOP-12483. Maintain wrapped SASL ordering for postponed IPC responses.
+    (Daryn Sharp via yliu)
+
   OPTIMIZATIONS
 
     HADOOP-12051. ProtobufRpcEngine.invoke() should use Exception.toString()

+ 34 - 30
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java

@@ -584,6 +584,11 @@ public abstract class Server {
     private final byte[] clientId;
     private final TraceScope traceScope; // the HTrace scope on the server side
 
+    private Call(Call call) {
+      this(call.callId, call.retryCount, call.rpcRequest, call.connection,
+          call.rpcKind, call.clientId, call.traceScope);
+    }
+
     public Call(int id, int retryCount, Writable param, 
         Connection connection) {
       this(id, retryCount, param, connection, RPC.RpcKind.RPC_BUILTIN,
@@ -614,12 +619,6 @@ public abstract class Server {
           + retryCount;
     }
 
-    public void setResponse(Throwable t) throws IOException {
-      setupResponse(new ByteArrayOutputStream(), this,
-          RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER,
-          null, t.getClass().getName(), StringUtils.stringifyException(t));
-    }
-
     public void setResponse(ByteBuffer response) {
       this.rpcResponse = response;
     }
@@ -644,14 +643,23 @@ public abstract class Server {
       int count = responseWaitCount.decrementAndGet();
       assert count >= 0 : "response has already been sent";
       if (count == 0) {
-        if (rpcResponse == null) {
-          // needed by postponed operations to indicate an exception has
-          // occurred.  it's too late to re-encode the response so just
-          // drop the connection.
-          connection.close();
-        } else {
-          connection.sendResponse(this);
-        }
+        connection.sendResponse(this);
+      }
+    }
+
+    @InterfaceStability.Unstable
+    @InterfaceAudience.LimitedPrivate({"HDFS"})
+    public void abortResponse(Throwable t) throws IOException {
+      // don't send response if the call was already sent or aborted.
+      if (responseWaitCount.getAndSet(-1) > 0) {
+        // clone the call to prevent a race with the other thread stomping
+        // on the response while being sent.  the original call is
+        // effectively discarded since the wait count won't hit zero
+        Call call = new Call(this);
+        setupResponse(new ByteArrayOutputStream(), call,
+            RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER,
+            null, t.getClass().getName(), StringUtils.stringifyException(t));
+        call.sendResponse();
       }
     }
 
@@ -1156,6 +1164,13 @@ public abstract class Server {
     //
     void doRespond(Call call) throws IOException {
       synchronized (call.connection.responseQueue) {
+        // must only wrap before adding to the responseQueue to prevent
+        // postponed responses from being encrypted and sent out of order.
+        if (call.connection.useWrap) {
+          ByteArrayOutputStream response = new ByteArrayOutputStream();
+          wrapWithSasl(response, call);
+          call.setResponse(ByteBuffer.wrap(response.toByteArray()));
+        }
         call.connection.responseQueue.addLast(call);
         if (call.connection.responseQueue.size() == 1) {
           processResponse(call.connection.responseQueue, true);
@@ -2314,15 +2329,11 @@ public abstract class Server {
           }
           CurCall.set(null);
           synchronized (call.connection.responseQueue) {
-            // setupResponse() needs to be sync'ed together with 
-            // responder.doResponse() since setupResponse may use
-            // SASL to encrypt response data and SASL enforces
-            // its own message ordering.
-            setupResponse(buf, call, returnStatus, detailedErr, 
+            setupResponse(buf, call, returnStatus, detailedErr,
                 value, errorClass, error);
-            
-            // Discard the large buf and reset it back to smaller size 
-            // to free up heap
+
+            // Discard the large buf and reset it back to smaller size
+            // to free up heap.
             if (buf.size() > maxRespSize) {
               LOG.warn("Large response size " + buf.size() + " for call "
                   + call.toString());
@@ -2582,9 +2593,6 @@ public abstract class Server {
       out.writeInt(fullLength);
       header.writeDelimitedTo(out);
     }
-    if (call.connection.useWrap) {
-      wrapWithSasl(responseBuf, call);
-    }
     call.setResponse(ByteBuffer.wrap(responseBuf.toByteArray()));
   }
   
@@ -2612,10 +2620,6 @@ public abstract class Server {
     out.writeInt(OLD_VERSION_FATAL_STATUS);   // write FATAL_STATUS
     WritableUtils.writeString(out, errorClass);
     WritableUtils.writeString(out, error);
-
-    if (call.connection.useWrap) {
-      wrapWithSasl(response, call);
-    }
     call.setResponse(ByteBuffer.wrap(response.toByteArray()));
   }
   
@@ -2623,7 +2627,7 @@ public abstract class Server {
   private static void wrapWithSasl(ByteArrayOutputStream response, Call call)
       throws IOException {
     if (call.connection.saslServer != null) {
-      byte[] token = response.toByteArray();
+      byte[] token = call.rpcResponse.array();
       // synchronization may be needed since there can be multiple Handler
       // threads using saslServer to wrap responses.
       synchronized (call.connection.saslServer) {

+ 113 - 1
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java

@@ -40,9 +40,21 @@ import java.security.PrivilegedExceptionAction;
 import java.security.Security;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.regex.Pattern;
 
 import javax.security.auth.callback.Callback;
@@ -65,6 +77,7 @@ import org.apache.hadoop.fs.CommonConfigurationKeys;
 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.ipc.Client.ConnectionId;
+import org.apache.hadoop.ipc.Server.Call;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.KerberosInfo;
 import org.apache.hadoop.security.SaslInputStream;
@@ -78,6 +91,7 @@ import org.apache.hadoop.security.SecurityInfo;
 import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.TestUserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
 import org.apache.hadoop.security.token.SecretManager;
 import org.apache.hadoop.security.token.SecretManager.InvalidToken;
 import org.apache.hadoop.security.token.Token;
@@ -85,6 +99,7 @@ import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.hadoop.security.token.TokenInfo;
 import org.apache.hadoop.security.token.TokenSelector;
 import org.apache.log4j.Level;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -294,10 +309,13 @@ public class TestSaslRPC {
   public interface TestSaslProtocol extends TestRPC.TestProtocol {
     public AuthMethod getAuthMethod() throws IOException;
     public String getAuthUser() throws IOException;
+    public String echoPostponed(String value) throws IOException;
+    public void sendPostponed() throws IOException;
   }
-  
+
   public static class TestSaslImpl extends TestRPC.TestImpl implements
       TestSaslProtocol {
+    private List<Call> postponedCalls = new ArrayList<Call>();
     @Override
     public AuthMethod getAuthMethod() throws IOException {
       return UserGroupInformation.getCurrentUser()
@@ -307,6 +325,21 @@ public class TestSaslRPC {
     public String getAuthUser() throws IOException {
       return UserGroupInformation.getCurrentUser().getUserName();
     }
+    @Override
+    public String echoPostponed(String value) {
+      Call call = Server.getCurCall().get();
+      call.postponeResponse();
+      postponedCalls.add(call);
+      return value;
+    }
+    @Override
+    public void sendPostponed() throws IOException {
+      Collections.shuffle(postponedCalls);
+      for (Call call : postponedCalls) {
+        call.sendResponse();
+      }
+      postponedCalls.clear();
+    }
   }
 
   public static class CustomSecurityInfo extends SecurityInfo {
@@ -844,6 +877,85 @@ public class TestSaslRPC {
     assertAuthEquals(KrbFailed,    getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID));
   }
 
+  // ensure that for all qop settings, client can handle postponed rpc
+  // responses.  basically ensures that the rpc server isn't encrypting
+  // and queueing the responses out of order.
+  @Test(timeout=10000)
+  public void testSaslResponseOrdering() throws Exception {
+    SecurityUtil.setAuthenticationMethod(
+        AuthenticationMethod.TOKEN, conf);
+    UserGroupInformation.setConfiguration(conf);
+
+    TestTokenSecretManager sm = new TestTokenSecretManager();
+    Server server = new RPC.Builder(conf)
+        .setProtocol(TestSaslProtocol.class)
+        .setInstance(new TestSaslImpl()).setBindAddress(ADDRESS).setPort(0)
+        .setNumHandlers(1) // prevents ordering issues when unblocking calls.
+        .setVerbose(true)
+        .setSecretManager(sm)
+        .build();
+    server.start();
+    try {
+      final InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      final UserGroupInformation clientUgi =
+          UserGroupInformation.createRemoteUser("client");
+      clientUgi.setAuthenticationMethod(AuthenticationMethod.TOKEN);
+
+      TestTokenIdentifier tokenId = new TestTokenIdentifier(
+          new Text(clientUgi.getUserName()));
+      Token<?> token = new Token<TestTokenIdentifier>(tokenId, sm);
+      SecurityUtil.setTokenService(token, addr);
+      clientUgi.addToken(token);
+      clientUgi.doAs(new PrivilegedExceptionAction<Void>() {
+        @Override
+        public Void run() throws Exception {
+          final TestSaslProtocol proxy = RPC.getProxy(TestSaslProtocol.class,
+              TestSaslProtocol.versionID, addr, conf);
+          final ExecutorService executor = Executors.newCachedThreadPool();
+          final AtomicInteger count = new AtomicInteger();
+          try {
+            // queue up a bunch of futures for postponed calls serviced
+            // in a random order.
+            Future<?>[] futures = new Future<?>[10];
+            for (int i=0; i < futures.length; i++) {
+              futures[i] = executor.submit(new Callable<Void>(){
+                @Override
+                public Void call() throws Exception {
+                  String expect = "future"+count.getAndIncrement();
+                  String answer = proxy.echoPostponed(expect);
+                  assertEquals(expect, answer);
+                  return null;
+                }
+              });
+              try {
+                // ensures the call is initiated and the response is blocked.
+                futures[i].get(100, TimeUnit.MILLISECONDS);
+              } catch (TimeoutException te) {
+                continue; // expected.
+              }
+              Assert.fail("future"+i+" did not block");
+            }
+            // triggers responses to be unblocked in a random order.  having
+            // only 1 handler ensures that the prior calls are already
+            // postponed.  1 handler also ensures that this call will
+            // timeout if the postponing doesn't work (ie. free up handler)
+            proxy.sendPostponed();
+            for (int i=0; i < futures.length; i++) {
+              LOG.info("waiting for future"+i);
+              futures[i].get();
+            }
+          } finally {
+            RPC.stopProxy(proxy);
+            executor.shutdownNow();
+          }
+          return null;
+        }
+      });
+    } finally {
+      server.stop();
+    }
+  }
+
 
   // test helpers