瀏覽代碼

HADOOP-2789. Race condition in IPC Server Responder that could close
connections early. (Raghu Angadi)


git-svn-id: https://svn.apache.org/repos/asf/hadoop/core/trunk@620596 13f79535-47bb-0310-9956-ffa450edef68

Raghu Angadi 17 年之前
父節點
當前提交
be510a66e9
共有 3 個文件被更改,包括 243 次插入和 80 次删除
  1. 3 0
      CHANGES.txt
  2. 89 80
      src/java/org/apache/hadoop/ipc/Server.java
  3. 151 0
      src/test/org/apache/hadoop/ipc/TestIPCServerResponder.java

+ 3 - 0
CHANGES.txt

@@ -33,6 +33,9 @@ Release 0.16.1 - Unrelease
 
   BUG FIXES
 
+    HADOOP-2789. Race condition in IPC Server Responder that could close
+                 connections early. (Raghu Angadi)
+    
     HADOOP-2785. minor. Fix a typo in Datanode block verification 
                  (Raghu Angadi)
     

+ 89 - 80
src/java/org/apache/hadoop/ipc/Server.java

@@ -25,6 +25,8 @@ import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 
 import java.nio.ByteBuffer;
+import java.nio.channels.CancelledKeyException;
+import java.nio.channels.ClosedChannelException;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.Selector;
 import java.nio.channels.ServerSocketChannel;
@@ -293,15 +295,9 @@ public abstract class Server {
             } catch (Exception e) {return;}
           }
           if (c.timedOut(currentTime)) {
-            synchronized (connectionList) {
-              if (connectionList.remove(c))
-                numConnections--;
-            }
-            try {
-              if (LOG.isDebugEnabled())
-                LOG.debug(getName() + ": disconnecting client " + c.getHostAddress());
-              c.close();
-            } catch (Exception e) {}
+            if (LOG.isDebugEnabled())
+              LOG.debug(getName() + ": disconnecting client " + c.getHostAddress());
+            closeConnection(c);
             numNuked++;
             end--;
             c = null;
@@ -334,7 +330,6 @@ public abstract class Server {
                   doRead(key);
               }
             } catch (IOException e) {
-              key.cancel();
             }
             key = null;
           }
@@ -369,15 +364,9 @@ public abstract class Server {
       if (key != null) {
         Connection c = (Connection)key.attachment();
         if (c != null) {
-          synchronized (connectionList) {
-            if (connectionList.remove(c))
-              numConnections--;
-          }
-          try {
-            if (LOG.isDebugEnabled())
-              LOG.debug(getName() + ": disconnecting client " + c.getHostAddress());
-            c.close();
-          } catch (Exception ex) {}
+          if (LOG.isDebugEnabled())
+            LOG.debug(getName() + ": disconnecting client " + c.getHostAddress());
+          closeConnection(c);
           c = null;
         }
       }
@@ -417,22 +406,15 @@ public abstract class Server {
       try {
         count = c.readAndProcess();
       } catch (Exception e) {
-        key.cancel();
         LOG.debug(getName() + ": readAndProcess threw exception " + e + ". Count of bytes read: " + count, e);
         count = -1; //so that the (count < 0) block is executed
       }
       if (count < 0) {
-        synchronized (connectionList) {
-          if (connectionList.remove(c))
-            numConnections--;
-        }
-        try {
-          if (LOG.isDebugEnabled())
-            LOG.debug(getName() + ": disconnecting client " + 
-                      c.getHostAddress() + ". Number of active connections: "+
-                      numConnections);
-          c.close();
-        } catch (Exception e) {}
+        if (LOG.isDebugEnabled())
+          LOG.debug(getName() + ": disconnecting client " + 
+                    c.getHostAddress() + ". Number of active connections: "+
+                    numConnections);
+        closeConnection(c);
         c = null;
       }
       else {
@@ -458,13 +440,13 @@ public abstract class Server {
   // Sends responses of RPC back to clients.
   private class Responder extends Thread {
     private Selector writeSelector;
-    private boolean pending;         // call waiting to be enqueued
+    private int pending;         // connections waiting to register
 
     Responder() throws IOException {
       this.setName("IPC Server Responder");
       this.setDaemon(true);
       writeSelector = Selector.open(); // create a selector
-      pending = false;
+      pending = 0;
     }
 
     @Override
@@ -474,13 +456,12 @@ public abstract class Server {
       long lastPurgeTime = 0;   // last check for old calls.
 
       while (running) {
-        SelectionKey key = null;
         try {
           waitPending();     // If a channel is being registered, wait.
           writeSelector.select(maxCallStartAge);
-          Iterator iter = writeSelector.selectedKeys().iterator();
+          Iterator<SelectionKey> iter = writeSelector.selectedKeys().iterator();
           while (iter.hasNext()) {
-            key = (SelectionKey)iter.next();
+            SelectionKey key = iter.next();
             iter.remove();
             try {
               if (key.isValid() && key.isWritable()) {
@@ -488,9 +469,7 @@ public abstract class Server {
               }
             } catch (IOException e) {
               LOG.info(getName() + ": doAsyncWrite threw exception " + e);
-              key.cancel();
             }
-            key = null;
           }
           long now = System.currentTimeMillis();
           if (now < lastPurgeTime + maxCallStartAge) {
@@ -504,7 +483,7 @@ public abstract class Server {
           LOG.debug("Checking for old call responses.");
           iter = writeSelector.keys().iterator();
           while (iter.hasNext()) {
-            key = (SelectionKey)iter.next();
+            SelectionKey key = iter.next();
             try {
               doPurge(key, now);
             } catch (IOException e) {
@@ -535,8 +514,20 @@ public abstract class Server {
       if (key.channel() != call.connection.channel) {
         throw new IOException("doAsyncWrite: bad channel");
       }
-      if (processResponse(call.connection.responseQueue)) {
-        key.cancel();          // remove item from selector.
+
+      synchronized(call.connection.responseQueue) {
+        if (processResponse(call.connection.responseQueue, false)) {
+          try {
+            key.interestOps(0);
+          } catch (CancelledKeyException e) {
+            /* The Listener/reader might have closed the socket.
+             * We don't explicitly cancel the key, so not sure if this will
+             * ever fire.
+             * This warning could be removed.
+             */
+            LOG.warn("Exception while changing ops : " + e);
+          }
+        }
       }
     }
 
@@ -553,11 +544,22 @@ public abstract class Server {
         LOG.info("doPurge: bad channel");
         return;
       }
+      boolean close = false;
       LinkedList<Call> responseQueue = call.connection.responseQueue;
       synchronized (responseQueue) {
-        Iterator iter = responseQueue.listIterator(0);
+        Iterator<Call> iter = responseQueue.listIterator(0);
         while (iter.hasNext()) {
-          call = (Call)iter.next();
+          call = iter.next();
+          if (call.response.position() > 0) {
+            /* We should probably use a different start time
+             * than receivedTime. receivedTime starts when the RPC
+             * was first read.
+             * We have written a partial response. will close the
+             * connection for now.
+             */
+            close = true;
+            break;
+          }
           if (now > call.receivedTime + maxCallStartAge) {
             LOG.info(getName() + ", call " + call +
                      ": response discarded for being too old (" +
@@ -565,19 +567,18 @@ public abstract class Server {
             iter.remove();
           }
         }
-
-        // If all the calls for this channel were removed, then 
-        // remove this channel from the selector
-        if (responseQueue.size() == 0) {
-          key.cancel();
-        } 
+      }
+      
+      if (close) {
+        closeConnection(call.connection);
       }
     }
 
     // Processes one response. Returns true if there are no more pending
     // data for this channel.
     //
-    private boolean processResponse(LinkedList<Call> responseQueue) throws IOException {
+    private boolean processResponse(LinkedList<Call> responseQueue,
+                                    boolean inHandler) throws IOException {
       boolean error = true;
       boolean done = false;       // there is more data for this channel.
       int numElements = 0;
@@ -595,7 +596,6 @@ public abstract class Server {
           //
           // Extract the first call
           //
-          int numBytes = 0;
           call = responseQueue.removeFirst();
           SocketChannel channel = call.connection.channel;
           if (LOG.isDebugEnabled()) {
@@ -605,7 +605,10 @@ public abstract class Server {
           //
           // Send as much data as we can in the non-blocking fashion
           //
-          numBytes = channel.write(call.response);
+          int numBytes = channel.write(call.response);
+          if (numBytes < 0) {
+            return true;
+          }
           if (!call.response.hasRemaining()) {
             if (numElements == 1) {    // last call fully processes.
               done = true;             // no more data for this channel.
@@ -621,24 +624,27 @@ public abstract class Server {
             // If we were unable to write the entire response out, then 
             // insert in Selector queue. 
             //
-            call.connection.responseQueue.addFirst(call); 
-            setPending();
-            try {
-              // Wakeup the thread blocked on select, only then can the call 
-              // to channel.register() complete.
-              writeSelector.wakeup();
-              SelectionKey readKey = channel.register(writeSelector, 
-                                                      SelectionKey.OP_WRITE);
-              readKey.attach(call);
-            } finally {
-              clearPending();
+            call.connection.responseQueue.addFirst(call);
+            
+            if (inHandler) {
+              incPending();
+              try {
+                // Wakeup the thread blocked on select, only then can the call 
+                // to channel.register() complete.
+                writeSelector.wakeup();
+                channel.register(writeSelector, SelectionKey.OP_WRITE, call);
+              } catch (ClosedChannelException e) {
+                // It's OK: the channel might be closed elsewhere.
+                done = true;
+              } finally {
+                decPending();
+              }
             }
             if (LOG.isDebugEnabled()) {
               LOG.debug(getName() + ": responding to #" + call.id + " from " +
                         call.connection + " Wrote partial " + numBytes + 
                         " bytes.");
             }
-            done = false;             // this call not fully processed.
           }
           error = false;              // everything went off well
         }
@@ -646,11 +652,7 @@ public abstract class Server {
         if (error && call != null) {
           LOG.warn(getName()+", call " + call + ": output error");
           done = true;               // error. no more data for this channel.
-          synchronized (connectionList) {
-            if (connectionList.remove(call.connection))
-              numConnections--;
-          }
-          call.connection.close();
+          closeConnection(call.connection);
         }
       }
       return done;
@@ -663,22 +665,22 @@ public abstract class Server {
       synchronized (call.connection.responseQueue) {
         call.connection.responseQueue.addLast(call);
         if (call.connection.responseQueue.size() == 1) {
-          processResponse(call.connection.responseQueue);
+          processResponse(call.connection.responseQueue, true);
         }
       }
     }
 
-    private synchronized void setPending() {   // call waiting to be enqueued.
-      pending = true;
+    private synchronized void incPending() {   // call waiting to be enqueued.
+      pending++;
     }
 
-    private synchronized void clearPending() { // call done enqueueing.
-      pending = false;
+    private synchronized void decPending() { // call done enqueueing.
+      pending--;
       notify();
     }
 
     private synchronized void waitPending() throws InterruptedException {
-      while (pending) {
+      while (pending > 0) {
         wait();
       }
     }
@@ -691,7 +693,6 @@ public abstract class Server {
     private boolean headerRead = false;  //if the connection header that
                                          //follows version is read.
     private SocketChannel channel;
-    private SelectionKey key;
     private ByteBuffer data;
     private ByteBuffer dataLengthBuffer;
     private LinkedList<Call> responseQueue;
@@ -706,7 +707,6 @@ public abstract class Server {
 
     public Connection(SelectionKey key, SocketChannel channel, 
                       long lastContact) {
-      this.key = key;
       this.channel = channel;
       this.lastContact = lastContact;
       this.data = null;
@@ -847,7 +847,7 @@ public abstract class Server {
         
     }
 
-    private void close() throws IOException {
+    private synchronized void close() throws IOException {
       data = null;
       dataLengthBuffer = null;
       if (!channel.isOpen())
@@ -857,8 +857,6 @@ public abstract class Server {
         try {channel.close();} catch(Exception e) {}
       }
       try {socket.close();} catch(Exception e) {}
-      try {key.cancel();} catch(Exception e) {}
-      key = null;
     }
   }
 
@@ -980,6 +978,17 @@ public abstract class Server {
     responder = new Responder();
   }
 
+  private void closeConnection(Connection connection) {
+    synchronized (connectionList) {
+      if (connectionList.remove(connection))
+        numConnections--;
+    }
+    try {
+      connection.close();
+    } catch (IOException e) {
+    }
+  }
+  
   /** Sets the timeout used for network i/o. */
   public void setTimeout(int timeout) { this.timeout = timeout; }
 

+ 151 - 0
src/test/org/apache/hadoop/ipc/TestIPCServerResponder.java

@@ -0,0 +1,151 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.ipc;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * This test provokes partial writes in the server, which is 
+ * serving multiple clients.
+ */
+public class TestIPCServerResponder extends TestCase {
+
+  public static final Log LOG = 
+            LogFactory.getLog("org.apache.hadoop.ipc.TestIPCServerResponder");
+
+  private static Configuration conf = new Configuration();
+
+  public TestIPCServerResponder(final String name) {
+    super(name);
+  }
+
+  private static final Random RANDOM = new Random();
+
+  private static final String ADDRESS = "0.0.0.0";
+
+  private static final int BYTE_COUNT = 1024;
+  private static final byte[] BYTES = new byte[BYTE_COUNT];
+  static {
+    for (int i = 0; i < BYTE_COUNT; i++)
+      BYTES[i] = (byte) ('a' + (i % 26));
+  }
+
+  private static class TestServer extends Server {
+
+    private boolean sleep;
+
+    public TestServer(final int handlerCount, final boolean sleep) 
+                                              throws IOException {
+      super(ADDRESS, 0, BytesWritable.class, handlerCount, conf);
+      this.setTimeout(1000);
+      // Set the buffer size to half of the maximum parameter/result size 
+      // to force the socket to block
+      this.setSocketSendBufSize(BYTE_COUNT / 2);
+      this.sleep = sleep;
+    }
+
+    @Override
+    public Writable call(final Writable param, final long receivedTime) 
+                                               throws IOException {
+      if (sleep) {
+        try {
+          Thread.sleep(RANDOM.nextInt(20)); // sleep a bit
+        } catch (InterruptedException e) {}
+      }
+      return param;
+    }
+  }
+
+  private static class Caller extends Thread {
+
+    private Client client;
+    private int count;
+    private InetSocketAddress address;
+    private boolean failed;
+
+    public Caller(final Client client, final InetSocketAddress address, 
+                                       final int count) {
+      this.client = client;
+      this.address = address;
+      this.count = count;
+      client.setTimeout(1000);
+    }
+
+    @Override
+    public void run() {
+      for (int i = 0; i < count; i++) {
+        try {
+          int byteSize = RANDOM.nextInt(BYTE_COUNT);
+          byte[] bytes = new byte[byteSize];
+          System.arraycopy(BYTES, 0, bytes, 0, byteSize);
+          Writable param = new BytesWritable(bytes);
+          Writable value = client.call(param, address);
+          Thread.sleep(RANDOM.nextInt(20));
+        } catch (Exception e) {
+          LOG.fatal("Caught: " + e);
+          failed = true;
+        }
+      }
+    }
+  }
+
+  public void testServerResponder() throws Exception {
+    testServerResponder(10, true, 1, 10, 200);
+  }
+
+  public void testServerResponder(final int handlerCount, 
+                                  final boolean handlerSleep, 
+                                  final int clientCount,
+                                  final int callerCount,
+                                  final int callCount) throws Exception {
+    Server server = new TestServer(handlerCount, handlerSleep);
+    server.start();
+
+    InetSocketAddress address = server.getListenerAddress();
+    Client[] clients = new Client[clientCount];
+    for (int i = 0; i < clientCount; i++) {
+      clients[i] = new Client(BytesWritable.class, conf);
+    }
+
+    Caller[] callers = new Caller[callerCount];
+    for (int i = 0; i < callerCount; i++) {
+      callers[i] = new Caller(clients[i % clientCount], address, callCount);
+      callers[i].start();
+    }
+    for (int i = 0; i < callerCount; i++) {
+      callers[i].join();
+      assertFalse(callers[i].failed);
+    }
+    for (int i = 0; i < clientCount; i++) {
+      clients[i].stop();
+    }
+    server.stop();
+  }
+
+}