|
@@ -17,156 +17,584 @@
|
|
|
*/
|
|
|
package org.apache.zookeeper.server.quorum;
|
|
|
|
|
|
+import java.io.BufferedInputStream;
|
|
|
+import java.io.IOException;
|
|
|
+import java.net.ConnectException;
|
|
|
+import java.net.InetAddress;
|
|
|
+import java.net.InetSocketAddress;
|
|
|
+import java.net.ServerSocket;
|
|
|
+import java.net.Socket;
|
|
|
+import java.net.SocketException;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Collection;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Random;
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+
|
|
|
+import javax.net.ssl.HandshakeCompletedEvent;
|
|
|
+import javax.net.ssl.HandshakeCompletedListener;
|
|
|
+import javax.net.ssl.SSLSocket;
|
|
|
+
|
|
|
import org.apache.zookeeper.PortAssignment;
|
|
|
-import org.apache.zookeeper.client.ZKClientConfig;
|
|
|
+import org.apache.zookeeper.common.BaseX509ParameterizedTestCase;
|
|
|
import org.apache.zookeeper.common.ClientX509Util;
|
|
|
-import org.apache.zookeeper.common.Time;
|
|
|
+import org.apache.zookeeper.common.KeyStoreFileType;
|
|
|
+import org.apache.zookeeper.common.X509Exception;
|
|
|
+import org.apache.zookeeper.common.X509KeyType;
|
|
|
+import org.apache.zookeeper.common.X509TestContext;
|
|
|
import org.apache.zookeeper.common.X509Util;
|
|
|
-import org.apache.zookeeper.server.ServerCnxnFactory;
|
|
|
+import org.junit.After;
|
|
|
import org.junit.Assert;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
+import org.junit.runner.RunWith;
|
|
|
+import org.junit.runners.Parameterized;
|
|
|
|
|
|
-import javax.net.ssl.HandshakeCompletedEvent;
|
|
|
-import javax.net.ssl.HandshakeCompletedListener;
|
|
|
-import javax.net.ssl.SSLSocket;
|
|
|
-import java.io.IOException;
|
|
|
-import java.net.ConnectException;
|
|
|
-import java.net.InetSocketAddress;
|
|
|
-import java.net.Socket;
|
|
|
+@RunWith(Parameterized.class)
|
|
|
+public class UnifiedServerSocketTest extends BaseX509ParameterizedTestCase {
|
|
|
|
|
|
-import static org.hamcrest.CoreMatchers.equalTo;
|
|
|
-import static org.junit.Assert.assertThat;
|
|
|
-
|
|
|
-public class UnifiedServerSocketTest {
|
|
|
+ @Parameterized.Parameters
|
|
|
+ public static Collection<Object[]> params() {
|
|
|
+ ArrayList<Object[]> result = new ArrayList<>();
|
|
|
+ int paramIndex = 0;
|
|
|
+ for (X509KeyType caKeyType : X509KeyType.values()) {
|
|
|
+ for (X509KeyType certKeyType : X509KeyType.values()) {
|
|
|
+ for (Boolean hostnameVerification : new Boolean[] { true, false }) {
|
|
|
+ result.add(new Object[]{
|
|
|
+ caKeyType,
|
|
|
+ certKeyType,
|
|
|
+ hostnameVerification,
|
|
|
+ paramIndex++
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
|
|
|
private static final int MAX_RETRIES = 5;
|
|
|
private static final int TIMEOUT = 1000;
|
|
|
+ private static final byte[] DATA_TO_CLIENT = "hello client".getBytes();
|
|
|
+ private static final byte[] DATA_FROM_CLIENT = "hello server".getBytes();
|
|
|
|
|
|
private X509Util x509Util;
|
|
|
- private int port;
|
|
|
- private volatile boolean handshakeCompleted;
|
|
|
+ private InetSocketAddress localServerAddress;
|
|
|
+ private final Object handshakeCompletedLock = new Object();
|
|
|
+ // access only inside synchronized(handshakeCompletedLock) { ... } blocks
|
|
|
+ private boolean handshakeCompleted = false;
|
|
|
+
|
|
|
+ public UnifiedServerSocketTest(
|
|
|
+ final X509KeyType caKeyType,
|
|
|
+ final X509KeyType certKeyType,
|
|
|
+ final Boolean hostnameVerification,
|
|
|
+ final Integer paramIndex) {
|
|
|
+ super(paramIndex, () -> {
|
|
|
+ try {
|
|
|
+ return X509TestContext.newBuilder()
|
|
|
+ .setTempDir(tempDir)
|
|
|
+ .setKeyStoreKeyType(certKeyType)
|
|
|
+ .setTrustStoreKeyType(caKeyType)
|
|
|
+ .setHostnameVerification(hostnameVerification)
|
|
|
+ .build();
|
|
|
+ } catch (Exception e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
|
|
|
@Before
|
|
|
public void setUp() throws Exception {
|
|
|
- handshakeCompleted = false;
|
|
|
-
|
|
|
- port = PortAssignment.unique();
|
|
|
+ localServerAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), PortAssignment.unique());
|
|
|
+ x509Util = new ClientX509Util();
|
|
|
+ x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS);
|
|
|
+ }
|
|
|
|
|
|
- String testDataPath = System.getProperty("test.data.dir", "build/test/data");
|
|
|
- System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, "org.apache.zookeeper.server.NettyServerCnxnFactory");
|
|
|
- System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, "org.apache.zookeeper.ClientCnxnSocketNetty");
|
|
|
- System.setProperty(ZKClientConfig.SECURE_CLIENT, "true");
|
|
|
+ @After
|
|
|
+ public void tearDown() throws Exception {
|
|
|
+ x509TestContext.clearSystemProperties(x509Util);
|
|
|
+ }
|
|
|
|
|
|
- x509Util = new ClientX509Util();
|
|
|
+ private static void forceClose(Socket s) {
|
|
|
+ if (s == null || s.isClosed()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ try {
|
|
|
+ s.close();
|
|
|
+ } catch (IOException e) {
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- System.setProperty(x509Util.getSslKeystoreLocationProperty(), testDataPath + "/ssl/testKeyStore.jks");
|
|
|
- System.setProperty(x509Util.getSslKeystorePasswdProperty(), "testpass");
|
|
|
- System.setProperty(x509Util.getSslTruststoreLocationProperty(), testDataPath + "/ssl/testTrustStore.jks");
|
|
|
- System.setProperty(x509Util.getSslTruststorePasswdProperty(), "testpass");
|
|
|
- System.setProperty(x509Util.getSslHostnameVerificationEnabledProperty(), "false");
|
|
|
+ private static void forceClose(ServerSocket s) {
|
|
|
+ if (s == null || s.isClosed()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ try {
|
|
|
+ s.close();
|
|
|
+ } catch (IOException e) {
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- @Test
|
|
|
- public void testConnectWithSSL() throws Exception {
|
|
|
- class ServerThread extends Thread {
|
|
|
- public void run() {
|
|
|
- try {
|
|
|
- Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept();
|
|
|
- ((SSLSocket)unifiedSocket).getSession(); // block until handshake completes
|
|
|
- } catch (IOException e) {
|
|
|
- e.printStackTrace();
|
|
|
+ private static final class UnifiedServerThread extends Thread {
|
|
|
+ private final byte[] dataToClient;
|
|
|
+ private List<byte[]> dataFromClients;
|
|
|
+ private ExecutorService workerPool;
|
|
|
+ private UnifiedServerSocket serverSocket;
|
|
|
+
|
|
|
+ UnifiedServerThread(X509Util x509Util,
|
|
|
+ InetSocketAddress bindAddress,
|
|
|
+ boolean allowInsecureConnection,
|
|
|
+ byte[] dataToClient) throws IOException {
|
|
|
+ this.dataToClient = dataToClient;
|
|
|
+ dataFromClients = new ArrayList<>();
|
|
|
+ workerPool = Executors.newCachedThreadPool();
|
|
|
+ serverSocket = new UnifiedServerSocket(x509Util, allowInsecureConnection);
|
|
|
+ serverSocket.bind(bindAddress);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void run() {
|
|
|
+ try {
|
|
|
+ Random rnd = new Random();
|
|
|
+ while (true) {
|
|
|
+ final Socket unifiedSocket = serverSocket.accept();
|
|
|
+ final boolean tcpNoDelay = rnd.nextBoolean();
|
|
|
+ unifiedSocket.setTcpNoDelay(tcpNoDelay);
|
|
|
+ unifiedSocket.setSoTimeout(TIMEOUT);
|
|
|
+ final boolean keepAlive = rnd.nextBoolean();
|
|
|
+ unifiedSocket.setKeepAlive(keepAlive);
|
|
|
+ // Note: getting the input stream should not block the thread or trigger mode detection.
|
|
|
+ BufferedInputStream bis = new BufferedInputStream(unifiedSocket.getInputStream());
|
|
|
+ workerPool.submit(new Runnable() {
|
|
|
+ @Override
|
|
|
+ public void run() {
|
|
|
+ try {
|
|
|
+ byte[] buf = new byte[1024];
|
|
|
+ int bytesRead = unifiedSocket.getInputStream().read(buf, 0, 1024);
|
|
|
+ // Make sure the settings applied above before the socket was potentially upgraded to
|
|
|
+ // TLS still apply.
|
|
|
+ Assert.assertEquals(tcpNoDelay, unifiedSocket.getTcpNoDelay());
|
|
|
+ Assert.assertEquals(TIMEOUT, unifiedSocket.getSoTimeout());
|
|
|
+ Assert.assertEquals(keepAlive, unifiedSocket.getKeepAlive());
|
|
|
+ if (bytesRead > 0) {
|
|
|
+ byte[] dataFromClient = new byte[bytesRead];
|
|
|
+ System.arraycopy(buf, 0, dataFromClient, 0, bytesRead);
|
|
|
+ synchronized (dataFromClients) {
|
|
|
+ dataFromClients.add(dataFromClient);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ unifiedSocket.getOutputStream().write(dataToClient);
|
|
|
+ unifiedSocket.getOutputStream().flush();
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ } finally {
|
|
|
+ forceClose(unifiedSocket);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ } finally {
|
|
|
+ forceClose(serverSocket);
|
|
|
+ workerPool.shutdown();
|
|
|
}
|
|
|
}
|
|
|
- ServerThread serverThread = new ServerThread();
|
|
|
- serverThread.start();
|
|
|
|
|
|
+ public void shutdown(long millis) throws InterruptedException {
|
|
|
+ forceClose(serverSocket); // this should break the run() loop
|
|
|
+ workerPool.awaitTermination(millis, TimeUnit.MILLISECONDS);
|
|
|
+ this.join(millis);
|
|
|
+ }
|
|
|
+
|
|
|
+ synchronized byte[] getDataFromClient(int index) {
|
|
|
+ return dataFromClients.get(index);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private SSLSocket connectWithSSL() throws IOException, X509Exception, InterruptedException {
|
|
|
SSLSocket sslSocket = null;
|
|
|
int retries = 0;
|
|
|
while (retries < MAX_RETRIES) {
|
|
|
try {
|
|
|
sslSocket = x509Util.createSSLSocket();
|
|
|
+ sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() {
|
|
|
+ @Override
|
|
|
+ public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) {
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ handshakeCompleted = true;
|
|
|
+ handshakeCompletedLock.notifyAll();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ });
|
|
|
sslSocket.setSoTimeout(TIMEOUT);
|
|
|
- sslSocket.connect(new InetSocketAddress(port), TIMEOUT);
|
|
|
+ sslSocket.connect(localServerAddress, TIMEOUT);
|
|
|
break;
|
|
|
} catch (ConnectException connectException) {
|
|
|
connectException.printStackTrace();
|
|
|
+ forceClose(sslSocket);
|
|
|
+ sslSocket = null;
|
|
|
Thread.sleep(TIMEOUT);
|
|
|
}
|
|
|
retries++;
|
|
|
}
|
|
|
|
|
|
- sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() {
|
|
|
- @Override
|
|
|
- public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) {
|
|
|
- completeHandshake();
|
|
|
+ Assert.assertNotNull("Failed to connect to server with SSL", sslSocket);
|
|
|
+ return sslSocket;
|
|
|
+ }
|
|
|
+
|
|
|
+ private Socket connectWithoutSSL() throws IOException, InterruptedException {
|
|
|
+ Socket socket = null;
|
|
|
+ int retries = 0;
|
|
|
+ while (retries < MAX_RETRIES) {
|
|
|
+ try {
|
|
|
+ socket = new Socket();
|
|
|
+ socket.setSoTimeout(TIMEOUT);
|
|
|
+ socket.connect(localServerAddress, TIMEOUT);
|
|
|
+ break;
|
|
|
+ } catch (ConnectException connectException) {
|
|
|
+ connectException.printStackTrace();
|
|
|
+ forceClose(socket);
|
|
|
+ socket = null;
|
|
|
+ Thread.sleep(TIMEOUT);
|
|
|
}
|
|
|
- });
|
|
|
- sslSocket.startHandshake();
|
|
|
+ retries++;
|
|
|
+ }
|
|
|
+ Assert.assertNotNull("Failed to connect to server without SSL", socket);
|
|
|
+ return socket;
|
|
|
+ }
|
|
|
+
|
|
|
+ // In the tests below, a "Strict" server means a UnifiedServerSocket that
|
|
|
+ // does not allow plaintext connections (in other words, it's SSL-only).
|
|
|
+ // A "Non Strict" server means a UnifiedServerSocket that allows both
|
|
|
+ // plaintext and SSL incoming connections.
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Attempting to connect to a SSL-or-plaintext server with SSL should work.
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testConnectWithSSLToNonStrictServer() throws Exception {
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ Socket sslSocket = connectWithSSL();
|
|
|
+ try {
|
|
|
+ sslSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ sslSocket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ if (!handshakeCompleted) {
|
|
|
+ handshakeCompletedLock.wait(TIMEOUT);
|
|
|
+ }
|
|
|
+ Assert.assertTrue(handshakeCompleted);
|
|
|
+ }
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+ } finally {
|
|
|
+ forceClose(sslSocket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- serverThread.join(TIMEOUT);
|
|
|
+ /**
|
|
|
+ * Attempting to connect to a SSL-only server with SSL should work.
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testConnectWithSSLToStrictServer() throws Exception {
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, false, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ Socket sslSocket = connectWithSSL();
|
|
|
+ try {
|
|
|
+ sslSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ sslSocket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
|
|
|
- long start = Time.currentElapsedTime();
|
|
|
- while (Time.currentElapsedTime() < start + TIMEOUT) {
|
|
|
- if (handshakeCompleted) {
|
|
|
- return;
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ if (!handshakeCompleted) {
|
|
|
+ handshakeCompletedLock.wait(TIMEOUT);
|
|
|
+ }
|
|
|
+ Assert.assertTrue(handshakeCompleted);
|
|
|
}
|
|
|
+
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+ } finally {
|
|
|
+ forceClose(sslSocket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- Assert.fail("failed to complete handshake");
|
|
|
+ /**
|
|
|
+ * Attempting to connect to a SSL-or-plaintext server without SSL should work.
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testConnectWithoutSSLToNonStrictServer() throws Exception {
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ Socket socket = connectWithoutSSL();
|
|
|
+ try {
|
|
|
+ socket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ socket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = socket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+ } finally {
|
|
|
+ forceClose(socket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- private void completeHandshake() {
|
|
|
- handshakeCompleted = true;
|
|
|
+ /**
|
|
|
+ * Attempting to connect to a SSL-or-plaintext server without SSL with a
|
|
|
+ * small initial data write should work. This makes sure that sending
|
|
|
+ * less than 5 bytes does not break the logic in the server's initial 5
|
|
|
+ * byte read.
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testConnectWithoutSSLToNonStrictServerPartialWrite() throws Exception {
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ Socket socket = connectWithoutSSL();
|
|
|
+ try {
|
|
|
+ // Write only 2 bytes of the message, wait a bit, then write the rest.
|
|
|
+ // This makes sure that writes smaller than 5 bytes don't break the plaintext mode on the server
|
|
|
+ // once it decides that the input doesn't look like a TLS handshake.
|
|
|
+ socket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2);
|
|
|
+ socket.getOutputStream().flush();
|
|
|
+ Thread.sleep(TIMEOUT / 2);
|
|
|
+ socket.getOutputStream().write(DATA_FROM_CLIENT, 2, DATA_FROM_CLIENT.length - 2);
|
|
|
+ socket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = socket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+ } finally {
|
|
|
+ forceClose(socket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Attempting to connect to a SSL-only server without SSL should fail.
|
|
|
+ */
|
|
|
@Test
|
|
|
- public void testConnectWithoutSSL() throws Exception {
|
|
|
- final byte[] testData = "hello there".getBytes();
|
|
|
- final String[] dataReadFromClient = {null};
|
|
|
-
|
|
|
- class ServerThread extends Thread {
|
|
|
- public void run() {
|
|
|
- try {
|
|
|
- Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept();
|
|
|
- unifiedSocket.getOutputStream().write(testData);
|
|
|
- unifiedSocket.getOutputStream().flush();
|
|
|
- byte[] inputbuff = new byte[5];
|
|
|
- unifiedSocket.getInputStream().read(inputbuff, 0, 5);
|
|
|
- dataReadFromClient[0] = new String(inputbuff);
|
|
|
- } catch (IOException e) {
|
|
|
- e.printStackTrace();
|
|
|
+ public void testConnectWithoutSSLToStrictServer() throws Exception {
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, false, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ Socket socket = connectWithoutSSL();
|
|
|
+ socket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ socket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ try {
|
|
|
+ socket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ } catch (SocketException e) {
|
|
|
+ // We expect the other end to hang up the connection
|
|
|
+ return;
|
|
|
+ } finally {
|
|
|
+ forceClose(socket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
+ }
|
|
|
+ Assert.fail("Expected server to hang up the connection. Read from server succeeded unexpectedly.");
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * This test makes sure that UnifiedServerSocket used properly (a single
|
|
|
+ * thread accept()-ing connections and handing the resulting sockets to
|
|
|
+ * other threads for processing) is not vulnerable to blocking the
|
|
|
+ * accept() thread while doing mode detection if a misbehaving client
|
|
|
+ * connects. A misbehaving client is one that either disconnects
|
|
|
+ * immediately, or connects but does not send any data.
|
|
|
+ *
|
|
|
+ * This version of the test uses a non-strict server socket (i.e. it
|
|
|
+ * accepts both TLS and plaintext connections).
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testTLSDetectionNonBlockingNonStrictServerIdleClient() throws Exception {
|
|
|
+ Socket badClientSocket = null;
|
|
|
+ Socket clientSocket = null;
|
|
|
+ Socket secureClientSocket = null;
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ try {
|
|
|
+ badClientSocket = connectWithoutSSL(); // Leave the bad client socket idle
|
|
|
+
|
|
|
+ clientSocket = connectWithoutSSL();
|
|
|
+ clientSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ clientSocket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = clientSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ Assert.assertFalse(handshakeCompleted);
|
|
|
+ }
|
|
|
+
|
|
|
+ secureClientSocket = connectWithSSL();
|
|
|
+ secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ secureClientSocket.getOutputStream().flush();
|
|
|
+ buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1));
|
|
|
+
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ if (!handshakeCompleted) {
|
|
|
+ handshakeCompletedLock.wait(TIMEOUT);
|
|
|
}
|
|
|
+ Assert.assertTrue(handshakeCompleted);
|
|
|
}
|
|
|
+ } finally {
|
|
|
+ forceClose(badClientSocket);
|
|
|
+ forceClose(clientSocket);
|
|
|
+ forceClose(secureClientSocket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
}
|
|
|
- ServerThread serverThread = new ServerThread();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Like the above test, but with a strict server socket (closes non-TLS
|
|
|
+ * connections after seeing that there is no handshake).
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testTLSDetectionNonBlockingStrictServerIdleClient() throws Exception {
|
|
|
+ Socket badClientSocket = null;
|
|
|
+ Socket secureClientSocket = null;
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, false, DATA_TO_CLIENT);
|
|
|
serverThread.start();
|
|
|
|
|
|
- Socket socket = null;
|
|
|
- int retries = 0;
|
|
|
- while (retries < MAX_RETRIES) {
|
|
|
- try {
|
|
|
- socket = new Socket();
|
|
|
- socket.setSoTimeout(TIMEOUT);
|
|
|
- socket.connect(new InetSocketAddress(port), TIMEOUT);
|
|
|
- break;
|
|
|
- } catch (ConnectException connectException) {
|
|
|
- connectException.printStackTrace();
|
|
|
- Thread.sleep(TIMEOUT);
|
|
|
+ try {
|
|
|
+ badClientSocket = connectWithoutSSL(); // Leave the bad client socket idle
|
|
|
+
|
|
|
+ secureClientSocket = connectWithSSL();
|
|
|
+ secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ secureClientSocket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ if (!handshakeCompleted) {
|
|
|
+ handshakeCompletedLock.wait(TIMEOUT);
|
|
|
+ }
|
|
|
+ Assert.assertTrue(handshakeCompleted);
|
|
|
}
|
|
|
- retries++;
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+ } finally {
|
|
|
+ forceClose(badClientSocket);
|
|
|
+ forceClose(secureClientSocket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- socket.getOutputStream().write("hellobello".getBytes());
|
|
|
- socket.getOutputStream().flush();
|
|
|
+ /**
|
|
|
+ * Similar to the tests above, but the bad client disconnects immediately
|
|
|
+ * without sending any data.
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testTLSDetectionNonBlockingNonStrictServerDisconnectedClient() throws Exception {
|
|
|
+ Socket clientSocket = null;
|
|
|
+ Socket secureClientSocket = null;
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ try {
|
|
|
+ Socket badClientSocket = connectWithoutSSL();
|
|
|
+ forceClose(badClientSocket); // close the bad client socket immediately
|
|
|
+
|
|
|
+ clientSocket = connectWithoutSSL();
|
|
|
+ clientSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ clientSocket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = clientSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ Assert.assertFalse(handshakeCompleted);
|
|
|
+ }
|
|
|
|
|
|
- byte[] readBytes = new byte[testData.length];
|
|
|
- socket.getInputStream().read(readBytes, 0, testData.length);
|
|
|
+ secureClientSocket = connectWithSSL();
|
|
|
+ secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ secureClientSocket.getOutputStream().flush();
|
|
|
+ buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1));
|
|
|
|
|
|
- serverThread.join(TIMEOUT);
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ if (!handshakeCompleted) {
|
|
|
+ handshakeCompletedLock.wait(TIMEOUT);
|
|
|
+ }
|
|
|
+ Assert.assertTrue(handshakeCompleted);
|
|
|
+ }
|
|
|
+ } finally {
|
|
|
+ forceClose(clientSocket);
|
|
|
+ forceClose(secureClientSocket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- Assert.assertArrayEquals(testData, readBytes);
|
|
|
- assertThat("Data sent by the client is invalid", dataReadFromClient[0], equalTo("hello"));
|
|
|
+ /**
|
|
|
+ * Like the above test, but with a strict server socket (closes non-TLS
|
|
|
+ * connections after seeing that there is no handshake).
|
|
|
+ */
|
|
|
+ @Test
|
|
|
+ public void testTLSDetectionNonBlockingStrictServerDisconnectedClient() throws Exception {
|
|
|
+ Socket secureClientSocket = null;
|
|
|
+ UnifiedServerThread serverThread = new UnifiedServerThread(
|
|
|
+ x509Util, localServerAddress, false, DATA_TO_CLIENT);
|
|
|
+ serverThread.start();
|
|
|
+
|
|
|
+ try {
|
|
|
+ Socket badClientSocket = connectWithoutSSL();
|
|
|
+ forceClose(badClientSocket); // close the bad client socket immediately
|
|
|
+
|
|
|
+ secureClientSocket = connectWithSSL();
|
|
|
+ secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT);
|
|
|
+ secureClientSocket.getOutputStream().flush();
|
|
|
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
|
|
|
+ int bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length);
|
|
|
+ Assert.assertEquals(buf.length, bytesRead);
|
|
|
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
|
|
|
+
|
|
|
+ synchronized (handshakeCompletedLock) {
|
|
|
+ if (!handshakeCompleted) {
|
|
|
+ handshakeCompletedLock.wait(TIMEOUT);
|
|
|
+ }
|
|
|
+ Assert.assertTrue(handshakeCompleted);
|
|
|
+ }
|
|
|
+ Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
|
|
|
+ } finally {
|
|
|
+ forceClose(secureClientSocket);
|
|
|
+ serverThread.shutdown(TIMEOUT);
|
|
|
+ }
|
|
|
}
|
|
|
}
|