Selaa lähdekoodia

ZOOKEEPER-3473: Improving successful TLS handshake throughput with concurrent control

When there are lots of clients trying to re-establish sessions, there might be lots of half finished handshake timed out, and those failed ones keep reconnecting to another server and restarting the handshake from beginning again, which caused herd effect.

And the number of total ZK sessions could be supported within session timeout are impacted a lot after enabling TLS.

To improve the throughput, we added the TLS concurrent control to reduce the herd effect, and from out benchmark this doubled the sessions we could support within session timeout.

E2E test result:

Tested performance and correctness from E2E. For correctness, tested both secure and insecure
connections, the outstandingHandshakeNum will go to 0 eventually.

For performance, tested with 110k sessions with 10s session timeout, there is no session expire when leader election triggered, while before it can only support 50k sessions.

Author: Fangmin Lyu <fangmin@apache.org>

Reviewers: Enrico Olivelli <eolivelli@apache.org>, Andor Molnar <andor@apache.org>

Closes #1027 from lvfangmin/ZOOKEEPER-3473
Fangmin Lyu 5 vuotta sitten
vanhempi
commit
804095c060

+ 8 - 0
zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md

@@ -998,6 +998,14 @@ property, when available, is noted below.
     **New in 3.6.0:**
     The size threshold after which a request is considered a large request. If it is -1, then all requests are considered small, effectively turning off large request throttling. The default is -1.
 
+* *outstandingHandshake.limit* 
+    (Jave system property only: **zookeeper.netty.server.outstandingHandshake.limit**)
+    The maximum in-flight TLS handshake connections could have in ZooKeeper, 
+    the connections exceed this limit will be rejected before starting handshake. 
+    This setting doesn't limit the max TLS concurrency, but helps avoid herd 
+    effect due to TLS handshake timeout when there are too many in-flight TLS 
+    handshakes. Set it to something like 250 is good enough to avoid herd effect.
+
 <a name="sc_clusterOptions"></a>
 
 #### Cluster Options

+ 15 - 0
zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java

@@ -71,6 +71,14 @@ public class NettyServerCnxn extends ServerCnxn {
 
     public int readIssuedAfterReadComplete;
 
+    private volatile HandshakeState handshakeState = HandshakeState.NONE;
+
+    public enum HandshakeState {
+        NONE,
+        STARTED,
+        FINISHED
+    }
+
     NettyServerCnxn(Channel channel, ZooKeeperServer zks, NettyServerCnxnFactory factory) {
         super(zks);
         this.channel = channel;
@@ -631,4 +639,11 @@ public class NettyServerCnxn extends ServerCnxn {
         return 0;
     }
 
+    public void setHandshakeState(HandshakeState state) {
+        this.handshakeState = state;
+    }
+
+    public HandshakeState getHandshakeState() {
+        return this.handshakeState;
+    }
 }

+ 49 - 0
zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java

@@ -69,6 +69,7 @@ import org.apache.zookeeper.common.NettyUtils;
 import org.apache.zookeeper.common.SSLContextAndOptions;
 import org.apache.zookeeper.common.X509Exception;
 import org.apache.zookeeper.common.X509Exception.SSLContextException;
+import org.apache.zookeeper.server.NettyServerCnxn.HandshakeState;
 import org.apache.zookeeper.server.auth.ProviderRegistry;
 import org.apache.zookeeper.server.auth.X509AuthenticationProvider;
 import org.apache.zookeeper.server.quorum.QuorumPeerConfig;
@@ -93,6 +94,18 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
      */
     private static final byte TLS_HANDSHAKE_RECORD_TYPE = 0x16;
 
+    private final AtomicInteger outstandingHandshake = new AtomicInteger();
+    public static final String OUTSTANDING_HANDSHAKE_LIMIT = "zookeeper.netty.server.outstandingHandshake.limit";
+    private int outstandingHandshakeLimit;
+    private boolean handshakeThrottlingEnabled;
+
+    public void setOutstandingHandshakeLimit(int limit) {
+        outstandingHandshakeLimit = limit;
+        handshakeThrottlingEnabled = (secure || shouldUsePortUnification) && outstandingHandshakeLimit > 0;
+        LOG.info("handshakeThrottlingEnabled = {}, {} = {}",
+                handshakeThrottlingEnabled, OUTSTANDING_HANDSHAKE_LIMIT, outstandingHandshakeLimit);
+    }
+
     private final ServerBootstrap bootstrap;
     private Channel parentChannel;
     private final ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns", new DefaultEventExecutor());
@@ -164,6 +177,8 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
         protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
             NettyServerCnxn cnxn = Objects.requireNonNull(context.channel().attr(CONNECTION_ATTRIBUTE).get());
             LOG.debug("creating plaintext handler for session {}", cnxn.getSessionId());
+            // Mark handshake finished if it's a insecure cnxn
+            updateHandshakeCountIfStarted(cnxn);
             allChannels.add(context.channel());
             addCnxn(cnxn);
             return super.newNonSslHandler(context);
@@ -171,6 +186,13 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
 
     }
 
+    private void updateHandshakeCountIfStarted(NettyServerCnxn cnxn) {
+        if (cnxn != null && cnxn.getHandshakeState() == HandshakeState.STARTED) {
+            cnxn.setHandshakeState(HandshakeState.FINISHED);
+            outstandingHandshake.addAndGet(-1);
+        }
+    }
+
     /**
      * This is an inner class since we need to extend ChannelDuplexHandler, but
      * NettyServerCnxnFactory already extends ServerCnxnFactory. By making it inner
@@ -202,6 +224,23 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
             NettyServerCnxn cnxn = new NettyServerCnxn(channel, zkServer, NettyServerCnxnFactory.this);
             ctx.channel().attr(CONNECTION_ATTRIBUTE).set(cnxn);
 
+            if (handshakeThrottlingEnabled) {
+                // Favor to check and throttling even in dual mode which
+                // accepts both secure and insecure connections, since
+                // it's more efficient than throttling when we know it's
+                // a secure connection in DualModeSslHandler.
+                //
+                // From benchmark, this reduced around 15% reconnect time.
+                int outstandingHandshakesNum = outstandingHandshake.addAndGet(1);
+                if (outstandingHandshakesNum > outstandingHandshakeLimit) {
+                    outstandingHandshake.addAndGet(-1);
+                    channel.close();
+                    ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED.add(1);
+                } else {
+                    cnxn.setHandshakeState(HandshakeState.STARTED);
+                }
+            }
+
             if (secure) {
                 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
                 Future<Channel> handshakeFuture = sslHandler.handshakeFuture();
@@ -224,6 +263,7 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
                 if (LOG.isTraceEnabled()) {
                     LOG.trace("Channel inactive caused close {}", cnxn);
                 }
+                updateHandshakeCountIfStarted(cnxn);
                 cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_DISCONNECTED);
             }
         }
@@ -234,6 +274,7 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
             NettyServerCnxn cnxn = ctx.channel().attr(CONNECTION_ATTRIBUTE).getAndSet(null);
             if (cnxn != null) {
                 LOG.debug("Closing {}", cnxn);
+                updateHandshakeCountIfStarted(cnxn);
                 cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_CLOSED_EXCEPTION);
             }
         }
@@ -339,6 +380,8 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
          * Only allow the connection to stay open if certificate passes auth
          */
         public void operationComplete(Future<Channel> future) {
+            updateHandshakeCountIfStarted(cnxn);
+
             if (future.isSuccess()) {
                 LOG.debug("Successful handshake with session 0x{}", Long.toHexString(cnxn.getSessionId()));
                 SSLEngine eng = sslHandler.engine();
@@ -451,6 +494,8 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
         this.advancedFlowControlEnabled = Boolean.getBoolean(NETTY_ADVANCED_FLOW_CONTROL);
         LOG.info("{} = {}", NETTY_ADVANCED_FLOW_CONTROL, this.advancedFlowControlEnabled);
 
+        setOutstandingHandshakeLimit(Integer.getInteger(OUTSTANDING_HANDSHAKE_LIMIT, -1));
+
         EventLoopGroup bossGroup = NettyUtils.newNioOrEpollEventLoopGroup(NettyUtils.getClientReachableLocalInetAddressCount());
         EventLoopGroup workerGroup = NettyUtils.newNioOrEpollEventLoopGroup();
         ServerBootstrap bootstrap = new ServerBootstrap().group(bossGroup, workerGroup)
@@ -756,4 +801,8 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
     public Channel getParentChannel() {
         return parentChannel;
     }
+
+    public int getOutstandingHandshakeNum() {
+        return outstandingHandshake.get();
+    }
 }

+ 3 - 0
zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java

@@ -229,6 +229,7 @@ public final class ServerMetrics {
         NETTY_QUEUED_BUFFER = metricsContext.getSummary("netty_queued_buffer_capacity", DetailLevel.BASIC);
 
         DIGEST_MISMATCHES_COUNT = metricsContext.getCounter("digest_mismatches_count");
+        TLS_HANDSHAKE_EXCEEDED = metricsContext.getCounter("tls_handshake_exceeded");
     }
 
     /**
@@ -441,6 +442,8 @@ public final class ServerMetrics {
     // txns to data tree.
     public final Counter DIGEST_MISMATCHES_COUNT;
 
+    public final Counter TLS_HANDSHAKE_EXCEEDED;
+
     private final MetricsProvider metricsProvider;
 
     public void resetAll() {

+ 8 - 0
zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java

@@ -1815,6 +1815,7 @@ public class ZooKeeperServer implements SessionExpirer, ServerStats.Provider {
         rootContext.registerGauge("max_client_response_size", stats.getClientResponseStats()::getMaxBufferSize);
         rootContext.registerGauge("min_client_response_size", stats.getClientResponseStats()::getMinBufferSize);
 
+        rootContext.registerGauge("outstanding_tls_handshake", this::getOutstandingHandshakeNum);
     }
 
     protected void unregisterMetrics() {
@@ -2074,4 +2075,11 @@ public class ZooKeeperServer implements SessionExpirer, ServerStats.Provider {
         return rv;
     }
 
+    public int getOutstandingHandshakeNum() {
+        if (serverCnxnFactory instanceof NettyServerCnxnFactory) {
+            return ((NettyServerCnxnFactory) serverCnxnFactory).getOutstandingHandshakeNum();
+        } else {
+            return 0;
+        }
+    }
 }

+ 98 - 1
zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java

@@ -19,11 +19,49 @@
 package org.apache.zookeeper.server;
 
 import java.net.InetSocketAddress;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.zookeeper.PortAssignment;
+import org.apache.zookeeper.WatchedEvent;
+import org.apache.zookeeper.Watcher;
+import org.apache.zookeeper.ZooKeeper;
+import org.apache.zookeeper.common.ClientX509Util;
+import org.apache.zookeeper.server.metric.SimpleCounter;
+import org.apache.zookeeper.test.ClientBase;
+import org.apache.zookeeper.test.SSLAuthTest;
+import org.hamcrest.Matchers;
 import org.junit.Assert;
 import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-public class NettyServerCnxnFactoryTest {
+
+public class NettyServerCnxnFactoryTest extends ClientBase {
+
+    private static final Logger LOG = LoggerFactory
+            .getLogger(NettyServerCnxnFactoryTest.class);
+
+    final LinkedBlockingQueue<ZooKeeper> zks = new LinkedBlockingQueue<ZooKeeper>();
+
+    @Override
+    public void setUp() throws Exception {
+        System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY,
+                "org.apache.zookeeper.server.NettyServerCnxnFactory");
+        super.setUp();
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        System.clearProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY);
+
+        // clean up
+        for (ZooKeeper zk : zks) {
+            zk.close();
+        }
+        super.tearDown();
+    }
 
     @Test
     public void testRebind() throws Exception {
@@ -58,4 +96,63 @@ public class NettyServerCnxnFactoryTest {
         Assert.assertTrue(factory.getParentChannel().isActive());
     }
 
+    @Test
+    public void testOutstandingHandshakeLimit() throws Exception {
+
+        SimpleCounter tlsHandshakeExceeded = (SimpleCounter) ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED;
+        tlsHandshakeExceeded.reset();
+        Assert.assertEquals(tlsHandshakeExceeded.get(), 0);
+
+        ClientX509Util x509Util = SSLAuthTest.setUpSecure();
+        NettyServerCnxnFactory factory = (NettyServerCnxnFactory) serverFactory;
+        factory.setSecure(true);
+        factory.setOutstandingHandshakeLimit(10);
+
+        int threadNum = 3;
+        int cnxnPerThread = 10;
+        Thread[] cnxnWorker = new Thread[threadNum];
+
+        AtomicInteger cnxnCreated = new AtomicInteger(0);
+        CountDownLatch latch = new CountDownLatch(1);
+
+        for (int i = 0; i < cnxnWorker.length; i++) {
+            cnxnWorker[i] = new Thread() {
+                @Override
+                public void run() {
+                    for (int i = 0; i < cnxnPerThread; i++) {
+                        try {
+                            zks.add(new ZooKeeper(hostPort, 3000, new Watcher() {
+                                @Override
+                                public void process(WatchedEvent event) {
+                                    int created = cnxnCreated.addAndGet(1);
+                                    if (created == threadNum * cnxnPerThread) {
+                                        latch.countDown();
+                                    }
+                                }
+                            }));
+                        } catch (Exception e) {
+                            LOG.info("Error while creating zk client", e);
+                        }
+                    }
+                }
+            };
+            cnxnWorker[i].start();
+        }
+
+        Assert.assertThat(latch.await(3, TimeUnit.SECONDS), Matchers.is(true));
+        LOG.info("created {} connections", threadNum * cnxnPerThread);
+
+        // Assert throttling not 0
+        long handshakeThrottledNum = tlsHandshakeExceeded.get();
+        LOG.info("TLS_HANDSHAKE_EXCEEDED: {}", handshakeThrottledNum);
+        Assert.assertThat("The number of handshake throttled should be "
+                + "greater than 0", handshakeThrottledNum, Matchers.greaterThan(0L));
+
+        // Assert there is no outstanding handshake anymore
+        int outstandingHandshakeNum = factory.getOutstandingHandshakeNum();
+        LOG.info("outstanding handshake is {}", outstandingHandshakeNum);
+        Assert.assertThat("The outstanding handshake number should be 0 "
+                + "after all cnxns established", outstandingHandshakeNum, Matchers.is(0));
+
+    }
 }

+ 25 - 1
zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java

@@ -155,7 +155,31 @@ public class CommandsTest extends ClientBase {
 
     @Test
     public void testMonitor() throws IOException, InterruptedException {
-        ArrayList<Field> fields = new ArrayList<>(Arrays.asList(new Field("version", String.class), new Field("avg_latency", Double.class), new Field("max_latency", Long.class), new Field("min_latency", Long.class), new Field("packets_received", Long.class), new Field("packets_sent", Long.class), new Field("num_alive_connections", Integer.class), new Field("outstanding_requests", Long.class), new Field("server_state", String.class), new Field("znode_count", Integer.class), new Field("watch_count", Integer.class), new Field("ephemerals_count", Integer.class), new Field("approximate_data_size", Long.class), new Field("open_file_descriptor_count", Long.class), new Field("max_file_descriptor_count", Long.class), new Field("last_client_response_size", Integer.class), new Field("max_client_response_size", Integer.class), new Field("min_client_response_size", Integer.class), new Field("uptime", Long.class), new Field("global_sessions", Long.class), new Field("local_sessions", Long.class), new Field("connection_drop_probability", Double.class)));
+        ArrayList<Field> fields = new ArrayList<>(Arrays.asList(
+                new Field("version", String.class),
+                new Field("avg_latency", Double.class),
+                new Field("max_latency", Long.class),
+                new Field("min_latency", Long.class),
+                new Field("packets_received", Long.class),
+                new Field("packets_sent", Long.class),
+                new Field("num_alive_connections", Integer.class),
+                new Field("outstanding_requests", Long.class),
+                new Field("server_state", String.class),
+                new Field("znode_count", Integer.class),
+                new Field("watch_count", Integer.class),
+                new Field("ephemerals_count", Integer.class),
+                new Field("approximate_data_size", Long.class),
+                new Field("open_file_descriptor_count", Long.class),
+                new Field("max_file_descriptor_count", Long.class),
+                new Field("last_client_response_size", Integer.class),
+                new Field("max_client_response_size", Integer.class),
+                new Field("min_client_response_size", Integer.class),
+                new Field("uptime", Long.class),
+                new Field("global_sessions", Long.class),
+                new Field("local_sessions", Long.class),
+                new Field("connection_drop_probability", Double.class),
+                new Field("outstanding_tls_handshake", Integer.class)
+        ));
         Map<String, Object> metrics = MetricsUtils.currentServerMetrics();
 
         for (String metric : metrics.keySet()) {