|
@@ -25,6 +25,10 @@ import java.io.DataOutput;
|
|
|
import java.io.IOException;
|
|
|
import java.security.PrivilegedExceptionAction;
|
|
|
import java.security.Security;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Enumeration;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.TreeMap;
|
|
|
|
|
@@ -38,6 +42,7 @@ import javax.security.sasl.RealmCallback;
|
|
|
import javax.security.sasl.Sasl;
|
|
|
import javax.security.sasl.SaslException;
|
|
|
import javax.security.sasl.SaslServer;
|
|
|
+import javax.security.sasl.SaslServerFactory;
|
|
|
|
|
|
import org.apache.commons.codec.binary.Base64;
|
|
|
import org.apache.commons.logging.Log;
|
|
@@ -63,6 +68,7 @@ public class SaslRpcServer {
|
|
|
public static final String SASL_DEFAULT_REALM = "default";
|
|
|
public static final Map<String, String> SASL_PROPS =
|
|
|
new TreeMap<String, String>();
|
|
|
+ private static SaslServerFactory saslFactory;
|
|
|
|
|
|
public static enum QualityOfProtection {
|
|
|
AUTHENTICATION("auth"),
|
|
@@ -151,7 +157,7 @@ public class SaslRpcServer {
|
|
|
new PrivilegedExceptionAction<SaslServer>() {
|
|
|
@Override
|
|
|
public SaslServer run() throws SaslException {
|
|
|
- return Sasl.createSaslServer(mechanism, protocol, serverId,
|
|
|
+ return saslFactory.createSaslServer(mechanism, protocol, serverId,
|
|
|
SaslRpcServer.SASL_PROPS, callback);
|
|
|
}
|
|
|
});
|
|
@@ -180,6 +186,7 @@ public class SaslRpcServer {
|
|
|
SASL_PROPS.put(Sasl.QOP, saslQOP.getSaslQop());
|
|
|
SASL_PROPS.put(Sasl.SERVER_AUTH, "true");
|
|
|
Security.addProvider(new SaslPlainServer.SecurityProvider());
|
|
|
+ saslFactory = new FastSaslServerFactory(SASL_PROPS);
|
|
|
}
|
|
|
|
|
|
static String encodeIdentifier(byte[] identifier) {
|
|
@@ -363,4 +370,47 @@ public class SaslRpcServer {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ // Sasl.createSaslServer is 100-200X slower than caching the factories!
|
|
|
+ private static class FastSaslServerFactory implements SaslServerFactory {
|
|
|
+ private final Map<String,List<SaslServerFactory>> factoryCache =
|
|
|
+ new HashMap<String,List<SaslServerFactory>>();
|
|
|
+
|
|
|
+ FastSaslServerFactory(Map<String,?> props) {
|
|
|
+ final Enumeration<SaslServerFactory> factories =
|
|
|
+ Sasl.getSaslServerFactories();
|
|
|
+ while (factories.hasMoreElements()) {
|
|
|
+ SaslServerFactory factory = factories.nextElement();
|
|
|
+ for (String mech : factory.getMechanismNames(props)) {
|
|
|
+ if (!factoryCache.containsKey(mech)) {
|
|
|
+ factoryCache.put(mech, new ArrayList<SaslServerFactory>());
|
|
|
+ }
|
|
|
+ factoryCache.get(mech).add(factory);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public SaslServer createSaslServer(String mechanism, String protocol,
|
|
|
+ String serverName, Map<String,?> props, CallbackHandler cbh)
|
|
|
+ throws SaslException {
|
|
|
+ SaslServer saslServer = null;
|
|
|
+ List<SaslServerFactory> factories = factoryCache.get(mechanism);
|
|
|
+ if (factories != null) {
|
|
|
+ for (SaslServerFactory factory : factories) {
|
|
|
+ saslServer = factory.createSaslServer(
|
|
|
+ mechanism, protocol, serverName, props, cbh);
|
|
|
+ if (saslServer != null) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return saslServer;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String[] getMechanismNames(Map<String, ?> props) {
|
|
|
+ return factoryCache.keySet().toArray(new String[0]);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|