소스 검색

ZOOKEEPER-4804. Use daemon threads for Netty client (#2142)

Co-authored-by: tison <wander4096@gmail.com>
Istvan Toth 1 년 전
부모
커밋
803c485db9

+ 23 - 9
zookeeper-server/src/main/java/org/apache/zookeeper/common/NettyUtils.java

@@ -28,6 +28,7 @@ import io.netty.channel.socket.ServerSocketChannel;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioServerSocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.util.concurrent.DefaultThreadFactory;
 import java.net.InetAddress;
 import java.net.NetworkInterface;
 import java.net.SocketException;
@@ -35,6 +36,7 @@ import java.util.Collections;
 import java.util.Enumeration;
 import java.util.HashSet;
 import java.util.Set;
+import java.util.concurrent.ThreadFactory;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -43,10 +45,22 @@ import org.slf4j.LoggerFactory;
  */
 public class NettyUtils {
 
+    public static final String THREAD_POOL_NAME_PREFIX = "zkNetty-";
     private static final Logger LOG = LoggerFactory.getLogger(NettyUtils.class);
-
     private static final int DEFAULT_INET_ADDRESS_COUNT = 1;
 
+    /**
+     * Returns a ThreadFactory which generates daemon threads, and uses
+     * the passed class's name to generate the thread names.
+     *
+     * @param clazz Class name to use for generating thread names
+     * @return Netty DefaultThreadFactory configured to create daemon threads
+     */
+    private static ThreadFactory createThreadFactory(String clazz) {
+        final String poolName = THREAD_POOL_NAME_PREFIX + clazz;
+        return new DefaultThreadFactory(poolName, true);
+    }
+
     /**
      * If {@link Epoll#isAvailable()} <code>== true</code>, returns a new
      * {@link EpollEventLoopGroup}, otherwise returns a new
@@ -55,11 +69,7 @@ public class NettyUtils {
      * @return a new {@link EventLoopGroup}.
      */
     public static EventLoopGroup newNioOrEpollEventLoopGroup() {
-        if (Epoll.isAvailable()) {
-            return new EpollEventLoopGroup();
-        } else {
-            return new NioEventLoopGroup();
-        }
+        return newNioOrEpollEventLoopGroup(0);
     }
 
     /**
@@ -72,9 +82,13 @@ public class NettyUtils {
      */
     public static EventLoopGroup newNioOrEpollEventLoopGroup(int nThreads) {
         if (Epoll.isAvailable()) {
-            return new EpollEventLoopGroup(nThreads);
+            final String clazz = EpollEventLoopGroup.class.getSimpleName();
+            final ThreadFactory factory = createThreadFactory(clazz);
+            return new EpollEventLoopGroup(nThreads, factory);
         } else {
-            return new NioEventLoopGroup(nThreads);
+            final String clazz = NioEventLoopGroup.class.getSimpleName();
+            final ThreadFactory factory = createThreadFactory(clazz);
+            return new NioEventLoopGroup(nThreads, factory);
         }
     }
 
@@ -145,7 +159,7 @@ public class NettyUtils {
                 }
             }
             LOG.debug("Detected {} local network addresses: {}", validInetAddresses.size(), validInetAddresses);
-            return validInetAddresses.size() > 0 ? validInetAddresses.size() : DEFAULT_INET_ADDRESS_COUNT;
+            return !validInetAddresses.isEmpty() ? validInetAddresses.size() : DEFAULT_INET_ADDRESS_COUNT;
         } catch (SocketException ex) {
             LOG.warn("Failed to list all network interfaces, assuming 1", ex);
             return DEFAULT_INET_ADDRESS_COUNT;

+ 36 - 1
zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnTest.java

@@ -41,23 +41,28 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.ProtocolException;
 import java.nio.charset.StandardCharsets;
+import java.util.List;
 import java.util.Random;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.zookeeper.AsyncCallback.DataCallback;
+import org.apache.zookeeper.ClientCnxnSocketNetty;
 import org.apache.zookeeper.CreateMode;
 import org.apache.zookeeper.KeeperException;
 import org.apache.zookeeper.ZooDefs.Ids;
 import org.apache.zookeeper.ZooKeeper;
+import org.apache.zookeeper.client.ZKClientConfig;
 import org.apache.zookeeper.common.ClientX509Util;
+import org.apache.zookeeper.common.NettyUtils;
 import org.apache.zookeeper.data.Stat;
 import org.apache.zookeeper.server.quorum.BufferStats;
 import org.apache.zookeeper.server.quorum.LeaderZooKeeperServer;
 import org.apache.zookeeper.test.ClientBase;
 import org.apache.zookeeper.test.SSLAuthTest;
 import org.apache.zookeeper.test.TestByteBufAllocator;
+import org.apache.zookeeper.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -323,6 +328,37 @@ public class NettyServerCnxnTest extends ClientBase {
         runEnableDisableThrottling(false, false);
     }
 
+    @Test
+    public void testNettyUsesDaemonThreads() throws Exception {
+        assertTrue(serverFactory instanceof NettyServerCnxnFactory,
+                "Didn't instantiate ServerCnxnFactory with NettyServerCnxnFactory!");
+
+        // Use Netty in the client to check the threads on both the client and server side
+        System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, ClientCnxnSocketNetty.class.getName());
+        try {
+            final ZooKeeperServer zkServer = serverFactory.getZooKeeperServer();
+            try (ZooKeeper zk = createClient()) {
+                final String path = "/a";
+                // make sure connection is established
+                zk.create(path, "test".getBytes(StandardCharsets.UTF_8), Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
+
+                List<Thread> threads = TestUtils.getAllThreads();
+                boolean foundThread = false;
+                for (Thread t : threads) {
+                    if (t.getName().startsWith(NettyUtils.THREAD_POOL_NAME_PREFIX)) {
+                        foundThread = true;
+                        assertTrue(t.isDaemon(), "All Netty threads started by ZK must daemon threads");
+                    }
+                }
+                assertTrue(foundThread, "Did not find any Netty ZK Threads");
+            } finally {
+                zkServer.shutdown();
+            }
+        } finally {
+            System.clearProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET);
+        }
+    }
+
     private void runEnableDisableThrottling(boolean secure, boolean randomDisableEnable) throws Exception {
         ClientX509Util x509Util = null;
         if (secure) {
@@ -432,5 +468,4 @@ public class NettyServerCnxnTest extends ClientBase {
             }
         }
     }
-
 }

+ 27 - 0
zookeeper-server/src/test/java/org/apache/zookeeper/test/TestUtils.java

@@ -21,6 +21,10 @@ package org.apache.zookeeper.test;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.fail;
 import java.io.File;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.zookeeper.WatchedEvent;
 
 /**
@@ -71,4 +75,27 @@ public class TestUtils {
         assertEquals(expected.getPath(), actual.getPath());
         assertEquals(expected.getZxid(), actual.getZxid());
     }
+
+    /**
+     * Return all threads
+     *
+     * Code based on commons-lang3 ThreadUtils
+     *
+     * @return all active threads
+     */
+    public static List<Thread> getAllThreads() {
+        ThreadGroup threadGroup = Thread.currentThread().getThreadGroup();
+        while (threadGroup != null && threadGroup.getParent() != null) {
+            threadGroup = threadGroup.getParent();
+        }
+
+        int count = threadGroup.activeCount();
+        Thread[] threads;
+        do {
+            threads = new Thread[count + count / 2 + 1]; //slightly grow the array size
+            count = threadGroup.enumerate(threads, true);
+            //return value of enumerate() must be strictly less than the array size according to javadoc
+        } while (count >= threads.length);
+        return Collections.unmodifiableList(Stream.of(threads).limit(count).collect(Collectors.toList()));
+    }
 }