Browse Source

ZOOKEEPER-4508: Expire session in client side to avoid endless connection loss

Reviewers: anmolnar
Author: kezhuw
Closes #2058 from kezhuw/ZOOKEEPER-4508-client-side-session-expiration
Kezhu Wang 7 tháng trước cách đây
mục cha
commit
890061841f

+ 28 - 6
zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxn.java

@@ -165,6 +165,8 @@ public class ClientCnxn {
 
     private int readTimeout;
 
+    private int expirationTimeout;
+
     private final int sessionTimeout;
 
     private final ZKWatchManager watchManager;
@@ -411,6 +413,7 @@ public class ClientCnxn {
 
         this.connectTimeout = sessionTimeout / hostProvider.size();
         this.readTimeout = sessionTimeout * 2 / 3;
+        this.expirationTimeout = sessionTimeout * 4 / 3;
 
         this.sendThread = new SendThread(clientCnxnSocket);
         this.eventThread = new EventThread();
@@ -803,6 +806,12 @@ public class ClientCnxn {
 
     }
 
+    /**
+     * Signals that the connect/read timeout elapsed while the session expiration
+     * timeout had not been reached yet.
+     *
+     * <p>Kept distinct from {@code SessionTimeoutException} so that the two timeouts
+     * can be told apart by type. NOTE(review): presumably the send loop reconnects
+     * after this exception - confirm against the catch block in the send thread.
+     */
+    private static class ConnectionTimeoutException extends IOException {
+
+        // IOException is Serializable; pin the id explicitly like SessionTimeoutException does.
+        private static final long serialVersionUID = 1L;
+
+        /**
+         * @param message human readable description of the timeout
+         */
+        public ConnectionTimeoutException(String message) {
+            super(message);
+        }
+    }
+
     private static class SessionTimeoutException extends IOException {
 
         private static final long serialVersionUID = 824482094072071178L;
@@ -1143,7 +1152,7 @@ public class ClientCnxn {
                         startConnect(serverAddress);
                         // Update now to start the connection timer right after we make a connection attempt
                         clientCnxnSocket.updateNow();
-                        clientCnxnSocket.updateLastSendAndHeard();
+                        clientCnxnSocket.updateLastSend();
                     }
 
                     if (state.isConnected()) {
@@ -1181,16 +1190,24 @@ public class ClientCnxn {
                         }
                         to = readTimeout - clientCnxnSocket.getIdleRecv();
                     } else {
-                        to = connectTimeout - clientCnxnSocket.getIdleRecv();
+                        to = connectTimeout - clientCnxnSocket.getIdleSend();
                     }
 
-                    if (to <= 0) {
+                    int expiration = expirationTimeout - clientCnxnSocket.getIdleRecv();
+                    if (expiration <= 0) {
                         String warnInfo = String.format(
                             "Client session timed out, have not heard from server in %dms for session id 0x%s",
                             clientCnxnSocket.getIdleRecv(),
                             Long.toHexString(sessionId));
                         LOG.warn(warnInfo);
+                        changeZkState(States.CLOSED);
                         throw new SessionTimeoutException(warnInfo);
+                    } else if (to <= 0) {
+                        String warnInfo = String.format(
+                            "Client connection timed out, have not heard from server in %dms for session id 0x%s",
+                            clientCnxnSocket.getIdleRecv(),
+                            Long.toHexString(sessionId));
+                        throw new ConnectionTimeoutException(warnInfo);
                     }
                     if (state.isConnected()) {
                         //1000(1 second) is to prevent race condition missing to send the second ping
@@ -1235,7 +1252,7 @@ public class ClientCnxn {
                     } else {
                         LOG.warn(
                             "Session 0x{} for server {}, Closing socket connection. "
-                                + "Attempting reconnect except it is a SessionExpiredException.",
+                                + "Attempting reconnect except it is a SessionExpiredException or SessionTimeoutException.",
                             Long.toHexString(getSessionId()),
                             serverAddress,
                             e);
@@ -1256,7 +1273,12 @@ public class ClientCnxn {
             if (state.isAlive()) {
                 eventThread.queueEvent(new WatchedEvent(Event.EventType.None, Event.KeeperState.Disconnected, null));
             }
-            eventThread.queueEvent(new WatchedEvent(Event.EventType.None, Event.KeeperState.Closed, null));
+            if (closing) {
+                eventThread.queueEvent(new WatchedEvent(Event.EventType.None, KeeperState.Closed, null));
+            } else if (state == States.CLOSED) {
+                eventThread.queueEvent(new WatchedEvent(Event.EventType.None, KeeperState.Expired, null));
+            }
+            eventThread.queueEventOfDeath();
 
             Login l = loginRef.getAndSet(null);
             if (l != null) {
@@ -1274,7 +1296,6 @@ public class ClientCnxn {
                 eventThread.queueEvent(new WatchedEvent(Event.EventType.None, Event.KeeperState.Disconnected, null));
             }
             clientCnxnSocket.updateNow();
-            clientCnxnSocket.updateLastSendAndHeard();
         }
 
         private void pingRwServer() throws RWServerFoundException {
@@ -1374,6 +1395,7 @@ public class ClientCnxn {
             }
 
             readTimeout = negotiatedSessionTimeout * 2 / 3;
+            expirationTimeout = negotiatedSessionTimeout * 4 / 3;
             connectTimeout = negotiatedSessionTimeout / hostProvider.size();
             hostProvider.onConnected();
             sessionId = _sessionId;

+ 2 - 0
zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxnSocket.java

@@ -65,7 +65,9 @@ abstract class ClientCnxnSocket {
     protected ByteBuffer incomingBuffer = lenBuffer;
     protected final AtomicLong sentCount = new AtomicLong(0L);
     protected final AtomicLong recvCount = new AtomicLong(0L);
+    // Used for reactive timeout detection, e.g. connection read timeout and session expiration timeout.
     protected long lastHeard;
+    // Used for proactive timeout detection, e.g. ping timeout and connection establishment timeout.
     protected long lastSend;
     protected long now;
     protected ClientCnxn.SendThread sendThread;

+ 1 - 1
zookeeper-server/src/test/java/org/apache/zookeeper/test/ReconfigTest.java

@@ -804,7 +804,7 @@ public class ReconfigTest extends ZKTestCase implements DataCallback {
                 Thread.sleep(1000);
                 zkArr[serverIndex].setData("/test", "teststr".getBytes(), -1);
                 fail("New client connected to new client port!");
-            } catch (KeeperException.ConnectionLossException e) {
+            } catch (KeeperException.ConnectionLossException | KeeperException.SessionExpiredException e) {
                 // Exception is expected
             }
 

+ 21 - 14
zookeeper-server/src/test/java/org/apache/zookeeper/test/SessionTest.java

@@ -25,6 +25,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
@@ -90,10 +91,18 @@ public class SessionTest extends ZKTestCase {
     private static class CountdownWatcher implements Watcher {
 
         volatile CountDownLatch clientConnected = new CountDownLatch(1);
+        final CountDownLatch sessionTerminated = new CountDownLatch(1);
 
         public void process(WatchedEvent event) {
-            if (event.getState() == KeeperState.SyncConnected) {
-                clientConnected.countDown();
+            switch (event.getState()) {
+                case SyncConnected:
+                    clientConnected.countDown();
+                    break;
+                case AuthFailed:
+                case Expired:
+                case Closed:
+                    sessionTerminated.countDown();
+                    break;
             }
         }
 
@@ -274,17 +283,15 @@ public class SessionTest extends ZKTestCase {
         // shutdown the server
         serverFactory.shutdown();
 
-        try {
-            Thread.sleep(10000);
-        } catch (InterruptedException e) {
-            // ignore
-        }
+        watcher.sessionTerminated.await();
 
-        // verify that the size is just 2 - ie connect then disconnect
-        // if the client attempts reconnect and we are not handling current
-        // state correctly (ie eventing on duplicate disconnects) then we'll
-        // see a disconnect for each failed connection attempt
-        assertEquals(2, watcher.states.size());
+        // verify that there is no duplicated disconnected event.
+        List<KeeperState> states = Arrays.asList(
+                KeeperState.SyncConnected,
+                KeeperState.Disconnected,
+                KeeperState.Expired
+        );
+        assertEquals(states, watcher.states);
 
         zk.close();
     }
@@ -319,11 +326,11 @@ public class SessionTest extends ZKTestCase {
 
     private class DupWatcher extends CountdownWatcher {
 
-        public List<WatchedEvent> states = new LinkedList<>();
+        public List<KeeperState> states = new LinkedList<>();
         public void process(WatchedEvent event) {
             super.process(event);
             if (event.getType() == EventType.None) {
-                states.add(event);
+                states.add(event.getState());
             }
         }
 

+ 104 - 0
zookeeper-server/src/test/java/org/apache/zookeeper/test/SessionTimeoutTest.java

@@ -20,9 +20,15 @@ package org.apache.zookeeper.test;
 
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 import java.io.IOException;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import org.apache.zookeeper.CreateMode;
@@ -48,6 +54,30 @@ public class SessionTimeoutTest extends ClientBase {
         zk = createClient();
     }
 
+    /**
+     * A local listening socket that never accepts connections.
+     *
+     * <p>The server socket is bound to an ephemeral port with a backlog of 1 and a
+     * single client socket is connected to it right away without ever being accepted.
+     * NOTE(review): presumably this fills the accept backlog so that later connection
+     * attempts stall instead of being refused - this depends on the platform, confirm.
+     */
+    private static class BusyServer implements AutoCloseable {
+        private final ServerSocket server;
+        private final Socket client;
+
+        /**
+         * Binds the listening socket and connects the never-accepted client to it.
+         *
+         * @throws IOException if binding or connecting fails; the listening socket is
+         *                     closed again when the client cannot be connected
+         */
+        public BusyServer() throws IOException {
+            this.server = new ServerSocket(0, 1);
+            try {
+                this.client = new Socket("127.0.0.1", server.getLocalPort());
+            } catch (IOException | RuntimeException e) {
+                // Do not leak the already bound listening socket when the client fails.
+                try {
+                    server.close();
+                } catch (IOException closeFailure) {
+                    e.addSuppressed(closeFailure);
+                }
+                throw e;
+            }
+        }
+
+        /**
+         * @return the ephemeral port the listening socket is bound to
+         */
+        public int getLocalPort() {
+            return server.getLocalPort();
+        }
+
+        /**
+         * @return the {@code host:port} string of the listening socket on 127.0.0.1
+         */
+        public String getHostPort() {
+            return String.format("127.0.0.1:%d", getLocalPort());
+        }
+
+        /**
+         * Closes the client socket and then the listening socket.
+         *
+         * @throws IOException if closing either socket fails
+         */
+        @Override
+        public void close() throws IOException {
+            // try-with-resources closes in reverse declaration order (client, then server)
+            // and still closes the server when closing the client throws.
+            try (ServerSocket listener = server; Socket occupant = client) {
+                // nothing to do: closing both sockets is the whole purpose
+            }
+        }
+    }
+
     @Test
     public void testSessionExpiration() throws InterruptedException, KeeperException {
         final CountDownLatch expirationLatch = new CountDownLatch(1);
@@ -72,6 +102,80 @@ public class SessionTimeoutTest extends ClientBase {
         assertTrue(gotException);
     }
 
+    /**
+     * Connects with a server list in which the real server is surrounded by entries
+     * pointing at a {@code BusyServer}, stops the real server for half of the session
+     * timeout, restarts it and expects the next state event to be SyncConnected.
+     */
+    @Test
+    public void testSessionRecoveredAfterMultipleFailedAttempts() throws Exception {
+        // the client created for the test is not needed here, close it to reduce noise
+        zk.close();
+
+        try (BusyServer unresponsive = new BusyServer()) {
+            final String busyHostPort = unresponsive.getHostPort();
+            List<String> endpoints = Arrays.asList(
+                    busyHostPort,
+                    busyHostPort,
+                    hostPort,
+                    busyHostPort,
+                    busyHostPort,
+                    busyHostPort);
+
+            zk = createClient(new CountdownWatcher(), String.join(",", endpoints));
+            stopServer();
+
+            // Sleep longer than connectTimeout but shorter than sessionTimeout.
+            Thread.sleep(zk.getSessionTimeout() / 2);
+
+            CompletableFuture<Void> reconnected = new CompletableFuture<>();
+            zk.register(event -> {
+                // any state other than SyncConnected fails the future
+                if (event.getState() != Watcher.Event.KeeperState.SyncConnected) {
+                    reconnected.completeExceptionally(new KeeperException.SessionExpiredException());
+                    return;
+                }
+                reconnected.complete(null);
+            });
+
+            startServer();
+            reconnected.join();
+        }
+    }
+
+    @Test
+    public void testSessionExpirationAfterAllServerDown() throws Exception {
+        // stop client also to gain less distraction
+        zk.close();
+
+        // small connection timeout to gain quick ci feedback
+        int sessionTimeout = 3000;
+        CompletableFuture<Void> expired = new CompletableFuture<>();
+        zk = createClient(new CountdownWatcher(), hostPort, sessionTimeout);
+        zk.register(event -> {
+            if (event.getState() == Watcher.Event.KeeperState.Expired) {
+                expired.complete(null);
+            }
+        });
+        stopServer();
+        expired.join();
+        assertThrows(KeeperException.SessionExpiredException.class, () -> zk.exists("/", null));
+    }
+
+    /**
+     * Creates a client while the server is stopped, waits for the Expired state event
+     * and then asserts that a request on that client fails with SessionExpiredException.
+     */
+    @Test
+    public void testSessionExpirationWhenNoServerUp() throws Exception {
+        // stop the previously created client too, so it does not distract from this test
+        zk.close();
+
+        stopServer();
+
+        // small session timeout to gain quick ci feedback
+        int sessionTimeout = 3000;
+        CompletableFuture<Void> expired = new CompletableFuture<>();
+        // Store the new client in the field: the assertion below has to exercise this
+        // client, not the one closed above, and the client must stay reachable.
+        // NOTE(review): presumably the fixture teardown closes zk - confirm.
+        zk = new TestableZooKeeper(hostPort, sessionTimeout, event -> {
+            if (event.getState() == Watcher.Event.KeeperState.Expired) {
+                expired.complete(null);
+            }
+        });
+        expired.join();
+        assertThrows(KeeperException.SessionExpiredException.class, () -> zk.exists("/", null));
+    }
+
     @Test
     public void testQueueEvent() throws InterruptedException, KeeperException {
         final CountDownLatch eventLatch = new CountDownLatch(1);