|
@@ -26,8 +26,10 @@ import java.io.IOException;
|
|
|
import java.lang.reflect.Field;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
-import java.util.concurrent.atomic.AtomicInteger;
|
|
|
+import java.util.concurrent.CountDownLatch;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
+import org.apache.zookeeper.ClientCnxn.EventThread;
|
|
|
import org.apache.zookeeper.ClientCnxn.SendThread;
|
|
|
import org.apache.zookeeper.Watcher.Event.KeeperState;
|
|
|
import org.apache.zookeeper.ZooDefs.Ids;
|
|
@@ -88,7 +90,7 @@ public class SaslAuthTest extends ClientBase {
|
|
|
System.clearProperty("java.security.auth.login.config");
|
|
|
}
|
|
|
|
|
|
- private AtomicInteger authFailed = new AtomicInteger(0);
|
|
|
+ private final CountDownLatch authFailed = new CountDownLatch(1);
|
|
|
|
|
|
@Override
|
|
|
protected TestableZooKeeper createClient(String hp)
|
|
@@ -102,7 +104,7 @@ public class SaslAuthTest extends ClientBase {
|
|
|
@Override
|
|
|
public synchronized void process(WatchedEvent event) {
|
|
|
if (event.getState() == KeeperState.AuthFailed) {
|
|
|
- authFailed.incrementAndGet();
|
|
|
+ authFailed.countDown();
|
|
|
}
|
|
|
else {
|
|
|
super.process(event);
|
|
@@ -210,4 +212,41 @@ public class SaslAuthTest extends ClientBase {
|
|
|
saslLoginFailedField.setBoolean(sendThread, true);
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void testThreadsShutdownOnAuthFailed() throws Exception {
|
|
|
+ MyWatcher watcher = new MyWatcher();
|
|
|
+ ZooKeeper zk = null;
|
|
|
+ try {
|
|
|
+ zk = new ZooKeeper(hostPort, CONNECTION_TIMEOUT, watcher);
|
|
|
+ watcher.waitForConnected(CONNECTION_TIMEOUT);
|
|
|
+ try {
|
|
|
+ zk.addAuthInfo("FOO", "BAR".getBytes());
|
|
|
+ zk.getData("/path1", false, null);
|
|
|
+ Assert.fail("Should get auth state error");
|
|
|
+ } catch (KeeperException.AuthFailedException e) {
|
|
|
+ if (!authFailed.await(CONNECTION_TIMEOUT, TimeUnit.MILLISECONDS)) {
|
|
|
+ Assert.fail("Should have called my watcher");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Field cnxnField = zk.getClass().getDeclaredField("cnxn");
|
|
|
+ cnxnField.setAccessible(true);
|
|
|
+ ClientCnxn clientCnxn = (ClientCnxn) cnxnField.get(zk);
|
|
|
+ Field sendThreadField = clientCnxn.getClass().getDeclaredField("sendThread");
|
|
|
+ sendThreadField.setAccessible(true);
|
|
|
+ SendThread sendThread = (SendThread) sendThreadField.get(clientCnxn);
|
|
|
+ Field eventThreadField = clientCnxn.getClass().getDeclaredField("eventThread");
|
|
|
+ eventThreadField.setAccessible(true);
|
|
|
+ EventThread eventThread = (EventThread) eventThreadField.get(clientCnxn);
|
|
|
+ sendThread.join(CONNECTION_TIMEOUT);
|
|
|
+ eventThread.join(CONNECTION_TIMEOUT);
|
|
|
+ Assert.assertFalse("SendThread did not shutdown after authFail", sendThread.isAlive());
|
|
|
+ Assert.assertFalse("EventThread did not shutdown after authFail",
|
|
|
+ eventThread.isAlive());
|
|
|
+ } finally {
|
|
|
+ if (zk != null) {
|
|
|
+ zk.close();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
}
|