Ver código fonte

ZOOKEEPER-3988: rg.apache.zookeeper.server.NettyServerCnxn.receiveMessage throws NullPointerException

Modifications:
- prevent the NPE, the code that throws NPE is only to record some metrics for non TLS requests

Related to:
- apache/pulsar#11070
- https://github.com/pravega/zookeeper-operator/issues/393

Author: Enrico Olivelli <eolivelli@apache.org>

Reviewers: Nicolo² Boschi <boschi1997@gmail.com>, Andor Molnar <andor@apache.org>, Mate Szalay-Beko <symat@apache.org>

Closes #1798 from eolivelli/fix/ZOOKEEPER-3988-npe
Enrico Olivelli 3 anos atrás
pai
commit
957f8fc0af

+ 12 - 6
zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java

@@ -259,14 +259,20 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
                 allChannels.add(ctx.channel());
                 addCnxn(cnxn);
             }
+
             if (ctx.channel().pipeline().get(SslHandler.class) == null) {
-                SocketAddress remoteAddress = cnxn.getChannel().remoteAddress();
-                if (remoteAddress != null
-                        && !((InetSocketAddress) remoteAddress).getAddress().isLoopbackAddress()) {
-                    LOG.trace("NettyChannelHandler channelActive: remote={} local={}", remoteAddress, cnxn.getChannel().localAddress());
-                    zkServer.serverStats().incrementNonMTLSRemoteConnCount();
+                if (zkServer != null) {
+                    SocketAddress remoteAddress = cnxn.getChannel().remoteAddress();
+                    if (remoteAddress != null
+                            && !((InetSocketAddress) remoteAddress).getAddress().isLoopbackAddress()) {
+                        LOG.trace("NettyChannelHandler channelActive: remote={} local={}", remoteAddress, cnxn.getChannel().localAddress());
+                        zkServer.serverStats().incrementNonMTLSRemoteConnCount();
+                    } else {
+                        zkServer.serverStats().incrementNonMTLSLocalConnCount();
+                    }
                 } else {
-                    zkServer.serverStats().incrementNonMTLSLocalConnCount();
+                    LOG.trace("Opened non-TLS connection from {} but zkServer is not running",
+                            cnxn.getChannel().remoteAddress());
                 }
             }
         }

+ 18 - 10
zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnTest.java

@@ -174,9 +174,22 @@ public class NettyServerCnxnTest extends ClientBase {
         }
     }
 
-    @SuppressWarnings("unchecked")
     @Test
     public void testNonMTLSRemoteConn() throws Exception {
+        LeaderZooKeeperServer zks = mock(LeaderZooKeeperServer.class);
+        when(zks.isRunning()).thenReturn(true);
+        ServerStats.Provider providerMock = mock(ServerStats.Provider.class);
+        when(zks.serverStats()).thenReturn(new ServerStats(providerMock));
+        testNonMTLSRemoteConn(zks);
+    }
+
+    @Test
+    public void testNonMTLSRemoteConnZookKeeperServerNotReady() throws Exception {
+        testNonMTLSRemoteConn(null);
+    }
+
+    @SuppressWarnings("unchecked")
+    private void testNonMTLSRemoteConn(ZooKeeperServer zks) throws Exception {
         Channel channel = mock(Channel.class);
         ChannelId id = mock(ChannelId.class);
         ChannelFuture success = mock(ChannelFuture.class);
@@ -192,23 +205,18 @@ public class NettyServerCnxnTest extends ClientBase {
         when(channel.remoteAddress()).thenReturn(address);
         when(channel.id()).thenReturn(id);
         NettyServerCnxnFactory factory = new NettyServerCnxnFactory();
-        LeaderZooKeeperServer zks = mock(LeaderZooKeeperServer.class);
         factory.setZooKeeperServer(zks);
         Attribute atr = mock(Attribute.class);
         Mockito.doReturn(atr).when(channel).attr(
                 Mockito.any()
         );
         doNothing().when(atr).set(Mockito.any());
-
-        when(zks.isRunning()).thenReturn(true);
-
-        ServerStats.Provider providerMock = mock(ServerStats.Provider.class);
-        when(zks.serverStats()).thenReturn(new ServerStats(providerMock));
-
         factory.channelHandler.channelActive(context);
 
-        assertEquals(0, zks.serverStats().getNonMTLSLocalConnCount());
-        assertEquals(1, zks.serverStats().getNonMTLSRemoteConnCount());
+        if (zks != null) {
+            assertEquals(0, zks.serverStats().getNonMTLSLocalConnCount());
+            assertEquals(1, zks.serverStats().getNonMTLSRemoteConnCount());
+        }
     }
 
     @Test