Browse Source

HADOOP-10720. KMS: Implement generateEncryptedKey and decryptEncryptedKey in the REST API. (asuresh via tucu)

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1612399 13f79535-47bb-0310-9956-ffa450edef68
Alejandro Abdelnur 10 years ago
parent
commit
0c1469ece3
16 changed files with 1267 additions and 22 deletions
  1. 3 0
      hadoop-common-project/hadoop-common/CHANGES.txt
  2. 35 3
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderCryptoExtension.java
  3. 157 1
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
  4. 8 0
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSRESTConstants.java
  5. 317 0
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/ValueQueue.java
  6. 27 0
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/fs/CommonConfigurationKeysPublic.java
  7. 33 0
      hadoop-common-project/hadoop-common/src/main/resources/core-default.xml
  8. 190 0
      hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestValueQueue.java
  9. 15 0
      hadoop-common-project/hadoop-kms/src/main/conf/kms-acls.xml
  10. 149 0
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/EagerKeyGeneratorKeyProviderCryptoExtension.java
  11. 96 1
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMS.java
  12. 4 1
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSACLs.java
  13. 20 1
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSServerJSONUtils.java
  14. 30 4
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSWebApp.java
  15. 83 0
      hadoop-common-project/hadoop-kms/src/site/apt/index.apt.vm
  16. 100 11
      hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java

+ 3 - 0
hadoop-common-project/hadoop-common/CHANGES.txt

@@ -186,6 +186,9 @@ Trunk (Unreleased)
     HADOOP-10750. KMSKeyProviderCache should be in hadoop-common. 
     (asuresh via tucu)
 
+    HADOOP-10720. KMS: Implement generateEncryptedKey and decryptEncryptedKey
+    in the REST API. (asuresh via tucu)
+
   BUG FIXES
 
     HADOOP-9451. Fault single-layer config if node group topology is enabled.

+ 35 - 3
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderCryptoExtension.java

@@ -27,17 +27,19 @@ import javax.crypto.spec.IvParameterSpec;
 import javax.crypto.spec.SecretKeySpec;
 
 import com.google.common.base.Preconditions;
+import org.apache.hadoop.classification.InterfaceAudience;
 
 /**
  * A KeyProvider with Cytographic Extensions specifically for generating
  * Encrypted Keys as well as decrypting them
  * 
  */
+@InterfaceAudience.Private
 public class KeyProviderCryptoExtension extends
     KeyProviderExtension<KeyProviderCryptoExtension.CryptoExtension> {
 
-  protected static final String EEK = "EEK";
-  protected static final String EK = "EK";
+  public static final String EEK = "EEK";
+  public static final String EK = "EK";
 
   /**
    * This is a holder class whose instance contains the keyVersionName, iv
@@ -81,6 +83,14 @@ public class KeyProviderCryptoExtension extends
    */
   public interface CryptoExtension extends KeyProviderExtension.Extension {
 
+    /**
+     * Calls to this method allows the underlying KeyProvider to warm-up any
+     * implementation specific caches used to store the Encrypted Keys.
+     * @param keyNames Array of Key Names
+     */
+    public void warmUpEncryptedKeys(String... keyNames)
+        throws IOException;
+
     /**
      * Generates a key material and encrypts it using the given key version name
      * and initialization vector. The generated key material is of the same
@@ -180,13 +190,35 @@ public class KeyProviderCryptoExtension extends
       return new KeyVersion(keyVer.getName(), EK, ek);
     }
 
+    @Override
+    public void warmUpEncryptedKeys(String... keyNames)
+        throws IOException {
+      // NO-OP since the default version does not cache any keys
+    }
+
   }
 
-  private KeyProviderCryptoExtension(KeyProvider keyProvider,
+  /**
+   * This constructor is to be used by sub classes that provide
+   * delegating/proxying functionality to the {@link KeyProviderCryptoExtension}
+   * @param keyProvider
+   * @param extension
+   */
+  protected KeyProviderCryptoExtension(KeyProvider keyProvider,
       CryptoExtension extension) {
     super(keyProvider, extension);
   }
 
+  /**
+   * Notifies the Underlying CryptoExtension implementation to warm up any
+   * implementation specific caches for the specified KeyVersions
+   * @param keyNames Arrays of key Names
+   */
+  public void warmUpEncryptedKeys(String... keyNames)
+      throws IOException {
+    getExtension().warmUpEncryptedKeys(keyNames);
+  }
+
   /**
    * Generates a key material and encrypts it using the given key version name
    * and initialization vector. The generated key material is of the same

+ 157 - 1
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java

@@ -21,7 +21,9 @@ import org.apache.commons.codec.binary.Base64;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
 import org.apache.hadoop.crypto.key.KeyProviderFactory;
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.security.ProviderUtils;
 import org.apache.hadoop.security.authentication.client.AuthenticatedURL;
@@ -33,6 +35,7 @@ import org.apache.http.client.utils.URIBuilder;
 import org.codehaus.jackson.map.ObjectMapper;
 
 import javax.net.ssl.HttpsURLConnection;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -40,6 +43,7 @@ import java.io.OutputStreamWriter;
 import java.io.Writer;
 import java.lang.reflect.Constructor;
 import java.net.HttpURLConnection;
+import java.net.SocketTimeoutException;
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.net.URL;
@@ -50,14 +54,22 @@ import java.text.MessageFormat;
 import java.util.ArrayList;
 import java.util.Date;
 import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ExecutionException;
+
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
+
+import com.google.common.base.Preconditions;
 
 /**
  * KMS client <code>KeyProvider</code> implementation.
  */
 @InterfaceAudience.Private
-public class KMSClientProvider extends KeyProvider {
+public class KMSClientProvider extends KeyProvider implements CryptoExtension {
 
   public static final String SCHEME_NAME = "kms";
 
@@ -78,6 +90,73 @@ public class KMSClientProvider extends KeyProvider {
   public static final String TIMEOUT_ATTR = CONFIG_PREFIX + "timeout";
   public static final int DEFAULT_TIMEOUT = 60;
 
+  private final ValueQueue<EncryptedKeyVersion> encKeyVersionQueue;
+
+  private class EncryptedQueueRefiller implements
+    ValueQueue.QueueRefiller<EncryptedKeyVersion> {
+
+    @Override
+    public void fillQueueForKey(String keyName,
+        Queue<EncryptedKeyVersion> keyQueue, int numEKVs) throws IOException {
+      checkNotNull(keyName, "keyName");
+      Map<String, String> params = new HashMap<String, String>();
+      params.put(KMSRESTConstants.EEK_OP, KMSRESTConstants.EEK_GENERATE);
+      params.put(KMSRESTConstants.EEK_NUM_KEYS, "" + numEKVs);
+      URL url = createURL(KMSRESTConstants.KEY_RESOURCE, keyName,
+          KMSRESTConstants.EEK_SUB_RESOURCE, params);
+      HttpURLConnection conn = createConnection(url, HTTP_GET);
+      conn.setRequestProperty(CONTENT_TYPE, APPLICATION_JSON_MIME);
+      List response = call(conn, null,
+          HttpURLConnection.HTTP_OK, List.class);
+      List<EncryptedKeyVersion> ekvs =
+          parseJSONEncKeyVersion(keyName, response);
+      keyQueue.addAll(ekvs);
+    }
+  }
+
+  public static class KMSEncryptedKeyVersion extends EncryptedKeyVersion {
+    public KMSEncryptedKeyVersion(String keyName, String keyVersionName,
+        byte[] iv, String encryptedVersionName, byte[] keyMaterial) {
+      super(keyName, keyVersionName, iv, new KMSKeyVersion(null, 
+          encryptedVersionName, keyMaterial));
+    }
+  }
+
+  @SuppressWarnings("rawtypes")
+  private static List<EncryptedKeyVersion>
+      parseJSONEncKeyVersion(String keyName, List valueList) {
+    List<EncryptedKeyVersion> ekvs = new LinkedList<EncryptedKeyVersion>();
+    if (!valueList.isEmpty()) {
+      for (Object values : valueList) {
+        Map valueMap = (Map) values;
+
+        String versionName = checkNotNull(
+                (String) valueMap.get(KMSRESTConstants.VERSION_NAME_FIELD),
+                KMSRESTConstants.VERSION_NAME_FIELD);
+
+        byte[] iv = Base64.decodeBase64(checkNotNull(
+                (String) valueMap.get(KMSRESTConstants.IV_FIELD),
+                KMSRESTConstants.IV_FIELD));
+
+        Map encValueMap = checkNotNull((Map)
+                valueMap.get(KMSRESTConstants.ENCRYPTED_KEY_VERSION_FIELD),
+                KMSRESTConstants.ENCRYPTED_KEY_VERSION_FIELD);
+
+        String encVersionName = checkNotNull((String)
+                encValueMap.get(KMSRESTConstants.VERSION_NAME_FIELD),
+                KMSRESTConstants.VERSION_NAME_FIELD);
+
+        byte[] encKeyMaterial = Base64.decodeBase64(checkNotNull((String)
+                encValueMap.get(KMSRESTConstants.MATERIAL_FIELD),
+                KMSRESTConstants.MATERIAL_FIELD));
+
+        ekvs.add(new KMSEncryptedKeyVersion(keyName, versionName, iv,
+            encVersionName, encKeyMaterial));
+      }
+    }
+    return ekvs;
+  }
+
   private static KeyVersion parseJSONKeyVersion(Map valueMap) {
     KeyVersion keyVersion = null;
     if (!valueMap.isEmpty()) {
@@ -208,6 +287,28 @@ public class KMSClientProvider extends KeyProvider {
     }
     int timeout = conf.getInt(TIMEOUT_ATTR, DEFAULT_TIMEOUT);
     configurator = new TimeoutConnConfigurator(timeout, sslFactory);
+    encKeyVersionQueue =
+        new ValueQueue<KeyProviderCryptoExtension.EncryptedKeyVersion>(
+            conf.getInt(
+                CommonConfigurationKeysPublic.KMS_CLIENT_ENC_KEY_CACHE_SIZE,
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_SIZE_DEFAULT),
+            conf.getFloat(
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_LOW_WATERMARK,
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_LOW_WATERMARK_DEFAULT),
+            conf.getInt(
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_MS,
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_DEFAULT),
+            conf.getInt(
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_NUM_REFILL_THREADS,
+                CommonConfigurationKeysPublic.
+                    KMS_CLIENT_ENC_KEY_CACHE_NUM_REFILL_THREADS_DEFAULT),
+            new EncryptedQueueRefiller());
   }
 
   private String createServiceURL(URL url) throws IOException {
@@ -527,6 +628,51 @@ public class KMSClientProvider extends KeyProvider {
     }
   }
 
+  @Override
+  public EncryptedKeyVersion generateEncryptedKey(
+      String encryptionKeyName) throws IOException, GeneralSecurityException {
+    try {
+      return encKeyVersionQueue.getNext(encryptionKeyName);
+    } catch (ExecutionException e) {
+      if (e.getCause() instanceof SocketTimeoutException) {
+        throw (SocketTimeoutException)e.getCause();
+      }
+      throw new IOException(e);
+    }
+  }
+
+  @SuppressWarnings("rawtypes")
+  @Override
+  public KeyVersion decryptEncryptedKey(
+      EncryptedKeyVersion encryptedKeyVersion) throws IOException,
+                                                      GeneralSecurityException {
+    checkNotNull(encryptedKeyVersion.getKeyVersionName(), "versionName");
+    checkNotNull(encryptedKeyVersion.getIv(), "iv");
+    Preconditions.checkArgument(encryptedKeyVersion.getEncryptedKey()
+        .getVersionName().equals(KeyProviderCryptoExtension.EEK),
+        "encryptedKey version name must be '%s', is '%s'",
+        KeyProviderCryptoExtension.EK, encryptedKeyVersion.getEncryptedKey()
+            .getVersionName());
+    checkNotNull(encryptedKeyVersion.getEncryptedKey(), "encryptedKey");
+    Map<String, String> params = new HashMap<String, String>();
+    params.put(KMSRESTConstants.EEK_OP, KMSRESTConstants.EEK_DECRYPT);
+    Map<String, Object> jsonPayload = new HashMap<String, Object>();
+    jsonPayload.put(KMSRESTConstants.NAME_FIELD,
+        encryptedKeyVersion.getKeyName());
+    jsonPayload.put(KMSRESTConstants.IV_FIELD, Base64.encodeBase64String(
+        encryptedKeyVersion.getIv()));
+    jsonPayload.put(KMSRESTConstants.MATERIAL_FIELD, Base64.encodeBase64String(
+            encryptedKeyVersion.getEncryptedKey().getMaterial()));
+    URL url = createURL(KMSRESTConstants.KEY_VERSION_RESOURCE,
+        encryptedKeyVersion.getKeyVersionName(),
+        KMSRESTConstants.EEK_SUB_RESOURCE, params);
+    HttpURLConnection conn = createConnection(url, HTTP_POST);
+    conn.setRequestProperty(CONTENT_TYPE, APPLICATION_JSON_MIME);
+    Map response =
+        call(conn, jsonPayload, HttpURLConnection.HTTP_OK, Map.class);
+    return parseJSONKeyVersion(response);
+  }
+
   @Override
   public List<KeyVersion> getKeyVersions(String name) throws IOException {
     checkNotEmpty(name, "name");
@@ -570,4 +716,14 @@ public class KMSClientProvider extends KeyProvider {
     // the server should not keep in memory state on behalf of clients either.
   }
 
+  @Override
+  public void warmUpEncryptedKeys(String... keyNames)
+      throws IOException {
+    try {
+      encKeyVersionQueue.initializeQueuesForKeys(keyNames);
+    } catch (ExecutionException e) {
+      throw new IOException(e);
+    }
+  }
+
 }

+ 8 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSRESTConstants.java

@@ -34,10 +34,16 @@ public class KMSRESTConstants {
   public static final String KEY_VERSION_RESOURCE = "keyversion";
   public static final String METADATA_SUB_RESOURCE = "_metadata";
   public static final String VERSIONS_SUB_RESOURCE = "_versions";
+  public static final String EEK_SUB_RESOURCE = "_eek";
   public static final String CURRENT_VERSION_SUB_RESOURCE = "_currentversion";
 
   public static final String KEY_OP = "key";
+  public static final String EEK_OP = "eek_op";
+  public static final String EEK_GENERATE = "generate";
+  public static final String EEK_DECRYPT = "decrypt";
+  public static final String EEK_NUM_KEYS = "num_keys";
 
+  public static final String IV_FIELD = "iv";
   public static final String NAME_FIELD = "name";
   public static final String CIPHER_FIELD = "cipher";
   public static final String LENGTH_FIELD = "length";
@@ -47,6 +53,8 @@ public class KMSRESTConstants {
   public static final String VERSIONS_FIELD = "versions";
   public static final String MATERIAL_FIELD = "material";
   public static final String VERSION_NAME_FIELD = "versionName";
+  public static final String ENCRYPTED_KEY_VERSION_FIELD =
+      "encryptedKeyVersion";
 
   public static final String ERROR_EXCEPTION_JSON = "exception";
   public static final String ERROR_MESSAGE_JSON = "message";

+ 317 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/ValueQueue.java

@@ -0,0 +1,317 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.crypto.key.kms;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.hadoop.classification.InterfaceAudience;
+
+/**
+ * A Utility class that maintains a Queue of entries for a given key. It tries
+ * to ensure that there is are always at-least <code>numValues</code> entries
+ * available for the client to consume for a particular key.
+ * It also uses an underlying Cache to evict queues for keys that have not been
+ * accessed for a configurable period of time.
+ * Implementing classes are required to implement the
+ * <code>QueueRefiller</code> interface that exposes a method to refill the
+ * queue, when empty
+ */
+@InterfaceAudience.Private
+public class ValueQueue <E> {
+
+  /**
+   * QueueRefiller interface a client must implement to use this class
+   */
+  public interface QueueRefiller <E> {
+    /**
+     * Method that has to be implemented by implementing classes to fill the
+     * Queue.
+     * @param keyName Key name
+     * @param keyQueue Queue that needs to be filled
+     * @param numValues number of Values to be added to the queue.
+     * @throws IOException
+     */
+    public void fillQueueForKey(String keyName,
+        Queue<E> keyQueue, int numValues) throws IOException;
+  }
+
+  private static final String REFILL_THREAD =
+      ValueQueue.class.getName() + "_thread";
+
+  private final LoadingCache<String, LinkedBlockingQueue<E>> keyQueues;
+  private final ThreadPoolExecutor executor;
+  private final UniqueKeyBlockingQueue queue = new UniqueKeyBlockingQueue();
+  private final QueueRefiller<E> refiller;
+  private final SyncGenerationPolicy policy;
+
+  private final int numValues;
+  private final float lowWatermark;
+
+  /**
+   * A <code>Runnable</code> which takes a string name.
+   */
+  private abstract static class NamedRunnable implements Runnable {
+    final String name;
+    private NamedRunnable(String keyName) {
+      this.name = keyName;
+    }
+  }
+
+  /**
+   * This backing blocking queue used in conjunction with the
+   * <code>ThreadPoolExecutor</code> used by the <code>ValueQueue</code>. This
+   * Queue accepts a task only if the task is not currently in the process
+   * of being run by a thread which is implied by the presence of the key
+   * in the <code>keysInProgress</code> set.
+   *
+   * NOTE: Only methods that ware explicitly called by the
+   * <code>ThreadPoolExecutor</code> need to be over-ridden.
+   */
+  private static class UniqueKeyBlockingQueue extends
+      LinkedBlockingQueue<Runnable> {
+
+    private static final long serialVersionUID = -2152747693695890371L;
+    private HashSet<String> keysInProgress = new HashSet<String>();
+
+    @Override
+    public synchronized void put(Runnable e) throws InterruptedException {
+      if (keysInProgress.add(((NamedRunnable)e).name)) {
+        super.put(e);
+      }
+    }
+
+    @Override
+    public Runnable take() throws InterruptedException {
+      Runnable k = super.take();
+      if (k != null) {
+        keysInProgress.remove(((NamedRunnable)k).name);
+      }
+      return k;
+    }
+
+    @Override
+    public Runnable poll(long timeout, TimeUnit unit)
+        throws InterruptedException {
+      Runnable k = super.poll(timeout, unit);
+      if (k != null) {
+        keysInProgress.remove(((NamedRunnable)k).name);
+      }
+      return k;
+    }
+
+  }
+
+  /**
+   * Policy to decide how many values to return to client when client asks for
+   * "n" values and Queue is empty.
+   * This decides how many values to return when client calls "getAtMost"
+   */
+  public static enum SyncGenerationPolicy {
+    ATLEAST_ONE, // Return atleast 1 value
+    LOW_WATERMARK, // Return min(n, lowWatermark * numValues) values
+    ALL // Return n values
+  }
+
+  /**
+   * Constructor takes the following tunable configuration parameters
+   * @param numValues The number of values cached in the Queue for a
+   *    particular key.
+   * @param lowWatermark The ratio of (number of current entries/numValues)
+   *    below which the <code>fillQueueForKey()</code> funciton will be
+   *    invoked to fill the Queue.
+   * @param expiry Expiry time after which the Key and associated Queue are
+   *    evicted from the cache.
+   * @param numFillerThreads Number of threads to use for the filler thread
+   * @param policy The SyncGenerationPolicy to use when client
+   *    calls "getAtMost"
+   * @param refiller implementation of the QueueRefiller
+   */
+  public ValueQueue(final int numValues, final float lowWatermark,
+      long expiry, int numFillerThreads, SyncGenerationPolicy policy,
+      final QueueRefiller<E> refiller) {
+    Preconditions.checkArgument(numValues > 0, "\"numValues\" must be > 0");
+    Preconditions.checkArgument(((lowWatermark > 0)&&(lowWatermark <= 1)),
+        "\"lowWatermark\" must be > 0 and <= 1");
+    Preconditions.checkArgument(expiry > 0, "\"expiry\" must be > 0");
+    Preconditions.checkArgument(numFillerThreads > 0,
+        "\"numFillerThreads\" must be > 0");
+    Preconditions.checkNotNull(policy, "\"policy\" must not be null");
+    this.refiller = refiller;
+    this.policy = policy;
+    this.numValues = numValues;
+    this.lowWatermark = lowWatermark;
+    keyQueues = CacheBuilder.newBuilder()
+            .expireAfterAccess(expiry, TimeUnit.MILLISECONDS)
+            .build(new CacheLoader<String, LinkedBlockingQueue<E>>() {
+                  @Override
+                  public LinkedBlockingQueue<E> load(String keyName)
+                      throws Exception {
+                    LinkedBlockingQueue<E> keyQueue =
+                        new LinkedBlockingQueue<E>();
+                    refiller.fillQueueForKey(keyName, keyQueue,
+                        (int)(lowWatermark * numValues));
+                    return keyQueue;
+                  }
+                });
+
+    executor =
+        new ThreadPoolExecutor(numFillerThreads, numFillerThreads, 0L,
+            TimeUnit.MILLISECONDS, queue, new ThreadFactoryBuilder()
+                .setDaemon(true)
+                .setNameFormat(REFILL_THREAD).build());
+    // To ensure all requests are first queued, make coreThreads = maxThreads
+    // and pre-start all the Core Threads.
+    executor.prestartAllCoreThreads();
+  }
+
+  public ValueQueue(final int numValues, final float lowWaterMark, long expiry,
+      int numFillerThreads, QueueRefiller<E> fetcher) {
+    this(numValues, lowWaterMark, expiry, numFillerThreads,
+        SyncGenerationPolicy.ALL, fetcher);
+  }
+
+  /**
+   * Initializes the Value Queues for the provided keys by calling the
+   * fill Method with "numInitValues" values
+   * @param keyNames Array of key Names
+   * @throws ExecutionException
+   */
+  public void initializeQueuesForKeys(String... keyNames)
+      throws ExecutionException {
+    for (String keyName : keyNames) {
+      keyQueues.get(keyName);
+    }
+  }
+
+  /**
+   * This removes the value currently at the head of the Queue for the
+   * provided key. Will immediately fire the Queue filler function if key
+   * does not exist.
+   * If Queue exists but all values are drained, It will ask the generator
+   * function to add 1 value to Queue and then drain it.
+   * @param keyName String key name
+   * @return E the next value in the Queue
+   * @throws IOException
+   * @throws ExecutionException
+   */
+  public E getNext(String keyName)
+      throws IOException, ExecutionException {
+    return getAtMost(keyName, 1).get(0);
+  }
+
+  /**
+   * This removes the "num" values currently at the head of the Queue for the
+   * provided key. Will immediately fire the Queue filler function if key
+   * does not exist
+   * How many values are actually returned is governed by the
+   * <code>SyncGenerationPolicy</code> specified by the user.
+   * @param keyName String key name
+   * @param num Minimum number of values to return.
+   * @return List<E> values returned
+   * @throws IOException
+   * @throws ExecutionException
+   */
+  public List<E> getAtMost(String keyName, int num) throws IOException,
+      ExecutionException {
+    LinkedBlockingQueue<E> keyQueue = keyQueues.get(keyName);
+    // Using poll to avoid race condition..
+    LinkedList<E> ekvs = new LinkedList<E>();
+    try {
+      for (int i = 0; i < num; i++) {
+        E val = keyQueue.poll();
+        // If queue is empty now, Based on the provided SyncGenerationPolicy,
+        // figure out how many new values need to be generated synchronously
+        if (val == null) {
+          // Synchronous call to get remaining values
+          int numToFill = 0;
+          switch (policy) {
+          case ATLEAST_ONE:
+            numToFill = (ekvs.size() < 1) ? 1 : 0;
+            break;
+          case LOW_WATERMARK:
+            numToFill =
+                Math.min(num, (int) (lowWatermark * numValues)) - ekvs.size();
+            break;
+          case ALL:
+            numToFill = num - ekvs.size();
+            break;
+          }
+          // Synchronous fill if not enough values found
+          if (numToFill > 0) {
+            refiller.fillQueueForKey(keyName, ekvs, numToFill);
+          }
+          // Asynch task to fill > lowWatermark
+          if (i <= (int) (lowWatermark * numValues)) {
+            submitRefillTask(keyName, keyQueue);
+          }
+          return ekvs;
+        }
+        ekvs.add(val);
+      }
+    } catch (Exception e) {
+      throw new IOException("Exeption while contacting value generator ", e);
+    }
+    return ekvs;
+  }
+
+  private void submitRefillTask(final String keyName,
+      final Queue<E> keyQueue) throws InterruptedException {
+    // The submit/execute method of the ThreadPoolExecutor is bypassed and
+    // the Runnable is directly put in the backing BlockingQueue so that we
+    // can control exactly how the runnable is inserted into the queue.
+    queue.put(
+        new NamedRunnable(keyName) {
+          @Override
+          public void run() {
+            int cacheSize = numValues;
+            int threshold = (int) (lowWatermark * (float) cacheSize);
+            // Need to ensure that only one refill task per key is executed
+            try {
+              if (keyQueue.size() < threshold) {
+                refiller.fillQueueForKey(name, keyQueue,
+                    cacheSize - keyQueue.size());
+              }
+            } catch (final Exception e) {
+              throw new RuntimeException(e);
+            }
+          }
+        }
+        );
+  }
+
+  /**
+   * Cleanly shutdown
+   */
+  public void shutdown() {
+    executor.shutdownNow();
+  }
+
+}

+ 27 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/fs/CommonConfigurationKeysPublic.java

@@ -285,5 +285,32 @@ public class CommonConfigurationKeysPublic {
   /** Class to override Impersonation provider */
   public static final String  HADOOP_SECURITY_IMPERSONATION_PROVIDER_CLASS =
     "hadoop.security.impersonation.provider.class";
+
+  //  <!--- KMSClientProvider configurations —>
+  /** See <a href="{@docRoot}/../core-default.html">core-default.xml</a> */
+  public static final String KMS_CLIENT_ENC_KEY_CACHE_SIZE =
+      "hadoop.security.kms.client.encrypted.key.cache.size";
+  /** Default value for KMS_CLIENT_ENC_KEY_CACHE_SIZE */
+  public static final int KMS_CLIENT_ENC_KEY_CACHE_SIZE_DEFAULT = 500;
+
+  /** See <a href="{@docRoot}/../core-default.html">core-default.xml</a> */
+  public static final String KMS_CLIENT_ENC_KEY_CACHE_LOW_WATERMARK =
+      "hadoop.security.kms.client.encrypted.key.cache.low-watermark";
+  /** Default value for KMS_CLIENT_ENC_KEY_CACHE_LOW_WATERMARK */
+  public static final float KMS_CLIENT_ENC_KEY_CACHE_LOW_WATERMARK_DEFAULT =
+      0.3f;
+
+  /** See <a href="{@docRoot}/../core-default.html">core-default.xml</a> */
+  public static final String KMS_CLIENT_ENC_KEY_CACHE_NUM_REFILL_THREADS =
+      "hadoop.security.kms.client.encrypted.key.cache.num.refill.threads";
+  /** Default value for KMS_CLIENT_ENC_KEY_NUM_REFILL_THREADS */
+  public static final int KMS_CLIENT_ENC_KEY_CACHE_NUM_REFILL_THREADS_DEFAULT =
+      2;
+
+  /** See <a href="{@docRoot}/../core-default.html">core-default.xml</a> */
+  public static final String KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_MS =
+      "hadoop.security.kms.client.encrypted.key.cache.expiry";
+  /** Default value for KMS_CLIENT_ENC_KEY_CACHE_EXPIRY (12 hrs)*/
+  public static final int KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_DEFAULT = 43200000;
 }
 

+ 33 - 0
hadoop-common-project/hadoop-common/src/main/resources/core-default.xml

@@ -1455,4 +1455,37 @@ for ldap providers in the same way as above does.
   <value>true</value>
   <description>Don't cache 'har' filesystem instances.</description>
 </property>
+
+<!--- KMSClientProvider configurations -->
+<property>
+  <name>hadoop.security.kms.client.encrypted.key.cache.size</name>
+  <value>500</value>
+  <description>
+    Size of the EncryptedKeyVersion cache Queue for each key
+  </description>
+</property>
+<property>
+  <name>hadoop.security.kms.client.encrypted.key.cache.low-watermark</name>
+  <value>0.3f</value>
+  <description>
+    If size of the EncryptedKeyVersion cache Queue falls below the
+    low watermark, this cache queue will be scheduled for a refill
+  </description>
+</property>
+<property>
+  <name>hadoop.security.kms.client.encrypted.key.cache.num.refill.threads</name>
+  <value>2</value>
+  <description>
+    Number of threads to use for refilling depleted EncryptedKeyVersion
+    cache Queues
+  </description>
+</property>
+<property>
+  <name>"hadoop.security.kms.client.encrypted.key.cache.expiry</name>
+  <value>43200000</value>
+  <description>
+    Cache expiry time for a Key, after which the cache Queue for this
+    key will be dropped. Default = 12hrs
+  </description>
+</property>
 </configuration>

+ 190 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestValueQueue.java

@@ -0,0 +1,190 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.crypto.key;
+
+import java.io.IOException;
+import java.util.Queue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.hadoop.crypto.key.kms.ValueQueue;
+import org.apache.hadoop.crypto.key.kms.ValueQueue.QueueRefiller;
+import org.apache.hadoop.crypto.key.kms.ValueQueue.SyncGenerationPolicy;
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.google.common.collect.Sets;
+
+public class TestValueQueue {
+
+  private static class FillInfo {
+    final int num;
+    final String key;
+    FillInfo(int num, String key) {
+      this.num = num;
+      this.key = key;
+    }
+  }
+
+  private static class MockFiller implements QueueRefiller<String> {
+    final LinkedBlockingQueue<FillInfo> fillCalls =
+        new LinkedBlockingQueue<FillInfo>();
+    @Override
+    public void fillQueueForKey(String keyName, Queue<String> keyQueue,
+        int numValues) throws IOException {
+      fillCalls.add(new FillInfo(numValues, keyName));
+      for(int i = 0; i < numValues; i++) {
+        keyQueue.add("test");
+      }
+    }
+    public FillInfo getTop() throws InterruptedException {
+      return fillCalls.poll(500, TimeUnit.MILLISECONDS);
+    }
+  }
+
+  /**
+   * Verifies that Queue is initially filled to "numInitValues"
+   */
+  @Test
+  public void testInitFill() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.1f, 300, 1,
+            SyncGenerationPolicy.ALL, filler);
+    Assert.assertEquals("test", vq.getNext("k1"));
+    Assert.assertEquals(1, filler.getTop().num);
+    vq.shutdown();
+  }
+
+  /**
+   * Verifies that Queue is initialized (Warmed-up) for provided keys
+   */
+  @Test
+  public void testWarmUp() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.5f, 300, 1,
+            SyncGenerationPolicy.ALL, filler);
+    vq.initializeQueuesForKeys("k1", "k2", "k3");
+    FillInfo[] fillInfos =
+      {filler.getTop(), filler.getTop(), filler.getTop()};
+    Assert.assertEquals(5, fillInfos[0].num);
+    Assert.assertEquals(5, fillInfos[1].num);
+    Assert.assertEquals(5, fillInfos[2].num);
+    Assert.assertEquals(Sets.newHashSet("k1", "k2", "k3"),
+        Sets.newHashSet(fillInfos[0].key,
+            fillInfos[1].key,
+            fillInfos[2].key));
+    vq.shutdown();
+  }
+
+  /**
+   * Verifies that the refill task is executed after "checkInterval" if
+   * num values below "lowWatermark"
+   */
+  @Test
+  public void testRefill() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.1f, 300, 1,
+            SyncGenerationPolicy.ALL, filler);
+    Assert.assertEquals("test", vq.getNext("k1"));
+    Assert.assertEquals(1, filler.getTop().num);
+    // Trigger refill
+    vq.getNext("k1");
+    Assert.assertEquals(1, filler.getTop().num);
+    Assert.assertEquals(10, filler.getTop().num);
+    vq.shutdown();
+  }
+
+  /**
+   * Verifies that the No refill Happens after "checkInterval" if
+   * num values above "lowWatermark"
+   */
+  @Test
+  public void testNoRefill() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.5f, 300, 1,
+            SyncGenerationPolicy.ALL, filler);
+    Assert.assertEquals("test", vq.getNext("k1"));
+    Assert.assertEquals(5, filler.getTop().num);
+    Assert.assertEquals(null, filler.getTop());
+    vq.shutdown();
+  }
+
+  /**
+   * Verify getAtMost when SyncGeneration Policy = ALL
+   */
+  @Test
+  public void testgetAtMostPolicyALL() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.1f, 300, 1,
+            SyncGenerationPolicy.ALL, filler);
+    Assert.assertEquals("test", vq.getNext("k1"));
+    Assert.assertEquals(1, filler.getTop().num);
+    // Drain completely
+    Assert.assertEquals(10, vq.getAtMost("k1", 10).size());
+    // Synchronous call
+    Assert.assertEquals(10, filler.getTop().num);
+    // Ask for more... return all
+    Assert.assertEquals(19, vq.getAtMost("k1", 19).size());
+    // Synchronous call (No Async call since num > lowWatermark)
+    Assert.assertEquals(19, filler.getTop().num);
+    vq.shutdown();
+  }
+
+  /**
+   * Verify getAtMost when SyncGeneration Policy = ALL
+   */
+  @Test
+  public void testgetAtMostPolicyATLEAST_ONE() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.3f, 300, 1,
+            SyncGenerationPolicy.ATLEAST_ONE, filler);
+    Assert.assertEquals("test", vq.getNext("k1"));
+    Assert.assertEquals(3, filler.getTop().num);
+    // Drain completely
+    Assert.assertEquals(2, vq.getAtMost("k1", 10).size());
+    // Asynch Refill call
+    Assert.assertEquals(10, filler.getTop().num);
+    vq.shutdown();
+  }
+
+  /**
+   * Verify getAtMost when SyncGeneration Policy = LOW_WATERMARK
+   */
+  @Test
+  public void testgetAtMostPolicyLOW_WATERMARK() throws Exception {
+    MockFiller filler = new MockFiller();
+    ValueQueue<String> vq =
+        new ValueQueue<String>(10, 0.3f, 300, 1,
+            SyncGenerationPolicy.LOW_WATERMARK, filler);
+    Assert.assertEquals("test", vq.getNext("k1"));
+    Assert.assertEquals(3, filler.getTop().num);
+    // Drain completely
+    Assert.assertEquals(3, vq.getAtMost("k1", 10).size());
+    // Synchronous call
+    Assert.assertEquals(1, filler.getTop().num);
+    // Asynch Refill call
+    Assert.assertEquals(10, filler.getTop().num);
+    vq.shutdown();
+  }
+}

+ 15 - 0
hadoop-common-project/hadoop-kms/src/main/conf/kms-acls.xml

@@ -79,4 +79,19 @@
     </description>
   </property>
 
+  <property>
+    <name>hadoop.kms.acl.GENERATE_EEK</name>
+    <value>*</value>
+    <description>
+      ACL for generateEncryptedKey CryptoExtension operations
+    </description>
+  </property>
+
+  <property>
+    <name>hadoop.kms.acl.DECRYPT_EEK</name>
+    <value>*</value>
+    <description>
+      ACL for decrypt EncryptedKey CryptoExtension operations
+    </description>
+  </property>
 </configuration>

+ 149 - 0
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/EagerKeyGeneratorKeyProviderCryptoExtension.java

@@ -0,0 +1,149 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.crypto.key.kms.server;
+
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ExecutionException;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.kms.ValueQueue;
+import org.apache.hadoop.crypto.key.kms.ValueQueue.SyncGenerationPolicy;
+
+/**
+ * A {@link KeyProviderCryptoExtension} that pre-generates and caches encrypted 
+ * keys.
+ */
+@InterfaceAudience.Private
+public class EagerKeyGeneratorKeyProviderCryptoExtension 
+    extends KeyProviderCryptoExtension {
+
+  private static final String KEY_CACHE_PREFIX =
+      "hadoop.security.kms.encrypted.key.cache.";
+
+  public static final String KMS_KEY_CACHE_SIZE =
+      KEY_CACHE_PREFIX + "size";
+  public static final int KMS_KEY_CACHE_SIZE_DEFAULT = 100;
+
+  public static final String KMS_KEY_CACHE_LOW_WATERMARK =
+      KEY_CACHE_PREFIX + "low.watermark";
+  public static final float KMS_KEY_CACHE_LOW_WATERMARK_DEFAULT = 0.30f;
+
+  public static final String KMS_KEY_CACHE_EXPIRY_MS =
+      KEY_CACHE_PREFIX + "expiry";
+  public static final int KMS_KEY_CACHE_EXPIRY_DEFAULT = 43200000;
+
+  public static final String KMS_KEY_CACHE_NUM_REFILL_THREADS =
+      KEY_CACHE_PREFIX + "num.fill.threads";
+  public static final int KMS_KEY_CACHE_NUM_REFILL_THREADS_DEFAULT = 2;
+
+
+  private static class CryptoExtension 
+      implements KeyProviderCryptoExtension.CryptoExtension {
+
+    private class EncryptedQueueRefiller implements
+        ValueQueue.QueueRefiller<EncryptedKeyVersion> {
+
+      @Override
+      public void fillQueueForKey(String keyName,
+          Queue<EncryptedKeyVersion> keyQueue, int numKeys) throws IOException {
+        List<EncryptedKeyVersion> retEdeks =
+            new LinkedList<EncryptedKeyVersion>();
+        for (int i = 0; i < numKeys; i++) {
+          try {
+            retEdeks.add(keyProviderCryptoExtension.generateEncryptedKey(
+                keyName));
+          } catch (GeneralSecurityException e) {
+            throw new IOException(e);
+          }
+        }
+        keyQueue.addAll(retEdeks);
+      }
+    }
+
+    private KeyProviderCryptoExtension keyProviderCryptoExtension;
+    private final ValueQueue<EncryptedKeyVersion> encKeyVersionQueue;
+
+    public CryptoExtension(Configuration conf, 
+        KeyProviderCryptoExtension keyProviderCryptoExtension) {
+      this.keyProviderCryptoExtension = keyProviderCryptoExtension;
+      encKeyVersionQueue =
+          new ValueQueue<KeyProviderCryptoExtension.EncryptedKeyVersion>(
+              conf.getInt(KMS_KEY_CACHE_SIZE,
+                  KMS_KEY_CACHE_SIZE_DEFAULT),
+              conf.getFloat(KMS_KEY_CACHE_LOW_WATERMARK,
+                  KMS_KEY_CACHE_LOW_WATERMARK_DEFAULT),
+              conf.getInt(KMS_KEY_CACHE_EXPIRY_MS,
+                  KMS_KEY_CACHE_EXPIRY_DEFAULT),
+              conf.getInt(KMS_KEY_CACHE_NUM_REFILL_THREADS,
+                  KMS_KEY_CACHE_NUM_REFILL_THREADS_DEFAULT),
+              SyncGenerationPolicy.LOW_WATERMARK, new EncryptedQueueRefiller()
+          );
+    }
+
+    @Override
+    public void warmUpEncryptedKeys(String... keyNames) throws
+                                                        IOException {
+      try {
+        encKeyVersionQueue.initializeQueuesForKeys(keyNames);
+      } catch (ExecutionException e) {
+        throw new IOException(e);
+      }
+    }
+
+    @Override
+    public EncryptedKeyVersion generateEncryptedKey(String encryptionKeyName)
+        throws IOException, GeneralSecurityException {
+      try {
+        return encKeyVersionQueue.getNext(encryptionKeyName);
+      } catch (ExecutionException e) {
+        throw new IOException(e);
+      }
+    }
+
+    @Override
+    public KeyVersion
+    decryptEncryptedKey(EncryptedKeyVersion encryptedKeyVersion)
+        throws IOException, GeneralSecurityException {
+      return keyProviderCryptoExtension.decryptEncryptedKey(
+          encryptedKeyVersion);
+    }
+  }
+
+  /**
+   * This class is a proxy for a <code>KeyProviderCryptoExtension</code> that
+   * decorates the underlying <code>CryptoExtension</code> with one that eagerly
+   * caches pre-generated Encrypted Keys using a <code>ValueQueue</code>
+   * 
+   * @param conf Configuration object to load parameters from
+   * @param keyProviderCryptoExtension <code>KeyProviderCryptoExtension</code>
+   * to delegate calls to.
+   */
+  public EagerKeyGeneratorKeyProviderCryptoExtension(Configuration conf,
+      KeyProviderCryptoExtension keyProviderCryptoExtension) {
+    super(keyProviderCryptoExtension, 
+        new CryptoExtension(conf, keyProviderCryptoExtension));
+  }
+
+}

+ 96 - 1
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMS.java

@@ -20,6 +20,8 @@ package org.apache.hadoop.crypto.key.kms.server;
 import org.apache.commons.codec.binary.Base64;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
 import org.apache.hadoop.crypto.key.kms.KMSRESTConstants;
 import org.apache.hadoop.security.AccessControlException;
 import org.apache.hadoop.security.authentication.client.AuthenticationException;
@@ -29,6 +31,7 @@ import org.apache.hadoop.util.StringUtils;
 
 import javax.ws.rs.Consumes;
 import javax.ws.rs.DELETE;
+import javax.ws.rs.DefaultValue;
 import javax.ws.rs.GET;
 import javax.ws.rs.POST;
 import javax.ws.rs.Path;
@@ -39,10 +42,14 @@ import javax.ws.rs.core.Context;
 import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.Response;
 import javax.ws.rs.core.SecurityContext;
+
+import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.security.Principal;
 import java.text.MessageFormat;
+import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 
@@ -61,8 +68,10 @@ public class KMS {
   private static final String GET_CURRENT_KEY = "GET_CURRENT_KEY";
   private static final String GET_KEY_VERSIONS = "GET_KEY_VERSIONS";
   private static final String GET_METADATA = "GET_METADATA";
+  private static final String GENERATE_EEK = "GENERATE_EEK";
+  private static final String DECRYPT_EEK = "DECRYPT_EEK";
 
-  private KeyProvider provider;
+  private KeyProviderCryptoExtension provider;
 
   public KMS() throws Exception {
     provider = KMSWebApp.getKeyProvider();
@@ -289,6 +298,92 @@ public class KMS {
     return Response.ok().type(MediaType.APPLICATION_JSON).entity(json).build();
   }
 
+  @SuppressWarnings({ "rawtypes", "unchecked" })
+  @GET
+  @Path(KMSRESTConstants.KEY_RESOURCE + "/{name:.*}/" +
+      KMSRESTConstants.EEK_SUB_RESOURCE)
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response generateEncryptedKeys(
+          @Context SecurityContext securityContext,
+          @PathParam("name") String name,
+          @QueryParam(KMSRESTConstants.EEK_OP) String edekOp,
+          @DefaultValue("1")
+          @QueryParam(KMSRESTConstants.EEK_NUM_KEYS) int numKeys)
+          throws Exception {
+    Principal user = getPrincipal(securityContext);
+    KMSClientProvider.checkNotEmpty(name, "name");
+    KMSClientProvider.checkNotNull(edekOp, "eekOp");
+
+    Object retJSON;
+    if (edekOp.equals(KMSRESTConstants.EEK_GENERATE)) {
+      assertAccess(KMSACLs.Type.GENERATE_EEK, user, GENERATE_EEK, name);
+
+      List<EncryptedKeyVersion> retEdeks =
+          new LinkedList<EncryptedKeyVersion>();
+      try {
+        for (int i = 0; i < numKeys; i ++) {
+          retEdeks.add(provider.generateEncryptedKey(name));
+        }
+      } catch (Exception e) {
+        throw new IOException(e);
+      }
+      KMSAudit.ok(user, GENERATE_EEK, name, "");
+      retJSON = new ArrayList();
+      for (EncryptedKeyVersion edek : retEdeks) {
+        ((ArrayList)retJSON).add(KMSServerJSONUtils.toJSON(edek));
+      }
+    } else {
+      throw new IllegalArgumentException("Wrong " + KMSRESTConstants.EEK_OP +
+          " value, it must be " + KMSRESTConstants.EEK_GENERATE + " or " +
+          KMSRESTConstants.EEK_DECRYPT);
+    }
+    KMSWebApp.getGenerateEEKCallsMeter().mark();
+    return Response.ok().type(MediaType.APPLICATION_JSON).entity(retJSON)
+        .build();
+  }
+
+  @SuppressWarnings("rawtypes")
+  @POST
+  @Path(KMSRESTConstants.KEY_VERSION_RESOURCE + "/{versionName:.*}/" +
+      KMSRESTConstants.EEK_SUB_RESOURCE)
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response decryptEncryptedKey(@Context SecurityContext securityContext,
+      @PathParam("versionName") String versionName,
+      @QueryParam(KMSRESTConstants.EEK_OP) String eekOp,
+      Map jsonPayload)
+      throws Exception {
+    Principal user = getPrincipal(securityContext);
+    KMSClientProvider.checkNotEmpty(versionName, "versionName");
+    KMSClientProvider.checkNotNull(eekOp, "eekOp");
+
+    String keyName = (String) jsonPayload.get(KMSRESTConstants.NAME_FIELD);
+    String ivStr = (String) jsonPayload.get(KMSRESTConstants.IV_FIELD);
+    String encMaterialStr = 
+        (String) jsonPayload.get(KMSRESTConstants.MATERIAL_FIELD);
+    Object retJSON;
+    if (eekOp.equals(KMSRESTConstants.EEK_DECRYPT)) {
+      assertAccess(KMSACLs.Type.DECRYPT_EEK, user, DECRYPT_EEK, versionName);
+      KMSClientProvider.checkNotNull(ivStr, KMSRESTConstants.IV_FIELD);
+      byte[] iv = Base64.decodeBase64(ivStr);
+      KMSClientProvider.checkNotNull(encMaterialStr,
+          KMSRESTConstants.MATERIAL_FIELD);
+      byte[] encMaterial = Base64.decodeBase64(encMaterialStr);
+      KeyProvider.KeyVersion retKeyVersion =
+          provider.decryptEncryptedKey(
+              new KMSClientProvider.KMSEncryptedKeyVersion(keyName, versionName,
+                  iv, KeyProviderCryptoExtension.EEK, encMaterial));
+      retJSON = KMSServerJSONUtils.toJSON(retKeyVersion);
+      KMSAudit.ok(user, DECRYPT_EEK, versionName, "");
+    } else {
+      throw new IllegalArgumentException("Wrong " + KMSRESTConstants.EEK_OP +
+          " value, it must be " + KMSRESTConstants.EEK_GENERATE + " or " +
+          KMSRESTConstants.EEK_DECRYPT);
+    }
+    KMSWebApp.getDecryptEEKCallsMeter().mark();
+    return Response.ok().type(MediaType.APPLICATION_JSON).entity(retJSON)
+        .build();
+  }
+
   @GET
   @Path(KMSRESTConstants.KEY_RESOURCE + "/{name:.*}/" +
       KMSRESTConstants.VERSIONS_SUB_RESOURCE)

+ 4 - 1
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSACLs.java

@@ -17,6 +17,7 @@
  */
 package org.apache.hadoop.crypto.key.kms.server;
 
+import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.authorize.AccessControlList;
@@ -34,12 +35,14 @@ import java.util.concurrent.TimeUnit;
  * hot-reloading them if the <code>kms-acls.xml</code> file where the ACLs
  * are defined has been updated.
  */
+@InterfaceAudience.Private
 public class KMSACLs implements Runnable {
   private static final Logger LOG = LoggerFactory.getLogger(KMSACLs.class);
 
 
   public enum Type {
-    CREATE, DELETE, ROLLOVER, GET, GET_KEYS, GET_METADATA, SET_KEY_MATERIAL;
+    CREATE, DELETE, ROLLOVER, GET, GET_KEYS, GET_METADATA,
+    SET_KEY_MATERIAL, GENERATE_EEK, DECRYPT_EEK;
 
     public String getConfigKey() {
       return KMSConfiguration.CONFIG_PREFIX + "acl." + this.toString();

+ 20 - 1
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSServerJSONUtils.java

@@ -17,8 +17,10 @@
  */
 package org.apache.hadoop.crypto.key.kms.server;
 
+import org.apache.commons.codec.binary.Base64;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
 import org.apache.hadoop.crypto.key.kms.KMSRESTConstants;
 
 import java.util.ArrayList;
@@ -39,7 +41,9 @@ public class KMSServerJSONUtils {
           keyVersion.getName());
       json.put(KMSRESTConstants.VERSION_NAME_FIELD,
           keyVersion.getVersionName());
-      json.put(KMSRESTConstants.MATERIAL_FIELD, keyVersion.getMaterial());
+      json.put(KMSRESTConstants.MATERIAL_FIELD,
+          Base64.encodeBase64URLSafeString(
+              keyVersion.getMaterial()));
     }
     return json;
   }
@@ -55,6 +59,21 @@ public class KMSServerJSONUtils {
     return json;
   }
 
+  @SuppressWarnings("unchecked")
+  public static Map toJSON(EncryptedKeyVersion encryptedKeyVersion) {
+    Map json = new LinkedHashMap();
+    if (encryptedKeyVersion != null) {
+      json.put(KMSRESTConstants.VERSION_NAME_FIELD,
+          encryptedKeyVersion.getKeyVersionName());
+      json.put(KMSRESTConstants.IV_FIELD,
+          Base64.encodeBase64URLSafeString(
+              encryptedKeyVersion.getIv()));
+      json.put(KMSRESTConstants.ENCRYPTED_KEY_VERSION_FIELD,
+          toJSON(encryptedKeyVersion.getEncryptedKey()));
+    }
+    return json;
+  }
+
   @SuppressWarnings("unchecked")
   public static Map toJSON(String keyName, KeyProvider.Metadata meta) {
     Map json = new LinkedHashMap();

+ 30 - 4
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSWebApp.java

@@ -20,10 +20,12 @@ package org.apache.hadoop.crypto.key.kms.server;
 import com.codahale.metrics.JmxReporter;
 import com.codahale.metrics.Meter;
 import com.codahale.metrics.MetricRegistry;
+
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.crypto.key.CachingKeyProvider;
 import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
 import org.apache.hadoop.crypto.key.KeyProviderFactory;
 import org.apache.hadoop.http.HttpServer2;
 import org.apache.hadoop.security.authorize.AccessControlList;
@@ -35,6 +37,7 @@ import org.slf4j.bridge.SLF4JBridgeHandler;
 
 import javax.servlet.ServletContextEvent;
 import javax.servlet.ServletContextListener;
+
 import java.io.File;
 import java.net.URL;
 import java.util.List;
@@ -55,6 +58,10 @@ public class KMSWebApp implements ServletContextListener {
       "unauthorized.calls.meter";
   private static final String UNAUTHENTICATED_CALLS_METER = METRICS_PREFIX +
       "unauthenticated.calls.meter";
+  private static final String GENERATE_EEK_METER = METRICS_PREFIX +
+      "generate_eek.calls.meter";
+  private static final String DECRYPT_EEK_METER = METRICS_PREFIX +
+      "decrypt_eek.calls.meter";
 
   private static Logger LOG;
   private static MetricRegistry metricRegistry;
@@ -66,8 +73,10 @@ public class KMSWebApp implements ServletContextListener {
   private static Meter keyCallsMeter;
   private static Meter unauthorizedCallsMeter;
   private static Meter unauthenticatedCallsMeter;
+  private static Meter decryptEEKCallsMeter;
+  private static Meter generateEEKCallsMeter;
   private static Meter invalidCallsMeter;
-  private static KeyProvider keyProvider;
+  private static KeyProviderCryptoExtension keyProviderCryptoExtension;
 
   static {
     SLF4JBridgeHandler.removeHandlersForRootLogger();
@@ -122,6 +131,10 @@ public class KMSWebApp implements ServletContextListener {
       metricRegistry = new MetricRegistry();
       jmxReporter = JmxReporter.forRegistry(metricRegistry).build();
       jmxReporter.start();
+      generateEEKCallsMeter = metricRegistry.register(GENERATE_EEK_METER,
+          new Meter());
+      decryptEEKCallsMeter = metricRegistry.register(DECRYPT_EEK_METER,
+          new Meter());
       adminCallsMeter = metricRegistry.register(ADMIN_CALLS_METER, new Meter());
       keyCallsMeter = metricRegistry.register(KEY_CALLS_METER, new Meter());
       invalidCallsMeter = metricRegistry.register(INVALID_CALLS_METER,
@@ -150,7 +163,7 @@ public class KMSWebApp implements ServletContextListener {
             "the first provider",
             kmsConf.get(KeyProviderFactory.KEY_PROVIDER_PATH));
       }
-      keyProvider = providers.get(0);
+      KeyProvider keyProvider = providers.get(0);
       if (kmsConf.getBoolean(KMSConfiguration.KEY_CACHE_ENABLE,
           KMSConfiguration.KEY_CACHE_ENABLE_DEFAULT)) {
         long keyTimeOutMillis =
@@ -162,6 +175,11 @@ public class KMSWebApp implements ServletContextListener {
         keyProvider = new CachingKeyProvider(keyProvider, keyTimeOutMillis,
             currKeyTimeOutMillis);
       }
+      keyProviderCryptoExtension = KeyProviderCryptoExtension.
+          createKeyProviderCryptoExtension(keyProvider);
+      keyProviderCryptoExtension = 
+          new EagerKeyGeneratorKeyProviderCryptoExtension(kmsConf, 
+              keyProviderCryptoExtension);
 
       LOG.info("KMS Started");
     } catch (Throwable ex) {
@@ -208,6 +226,14 @@ public class KMSWebApp implements ServletContextListener {
     return invalidCallsMeter;
   }
 
+  public static Meter getGenerateEEKCallsMeter() {
+    return generateEEKCallsMeter;
+  }
+
+  public static Meter getDecryptEEKCallsMeter() {
+    return decryptEEKCallsMeter;
+  }
+
   public static Meter getUnauthorizedCallsMeter() {
     return unauthorizedCallsMeter;
   }
@@ -216,7 +242,7 @@ public class KMSWebApp implements ServletContextListener {
     return unauthenticatedCallsMeter;
   }
 
-  public static KeyProvider getKeyProvider() {
-    return keyProvider;
+  public static KeyProviderCryptoExtension getKeyProvider() {
+    return keyProviderCryptoExtension;
   }
 }

+ 83 - 0
hadoop-common-project/hadoop-kms/src/site/apt/index.apt.vm

@@ -279,6 +279,25 @@ $ keytool -genkey -alias tomcat -keyalg RSA
         to provide the key material when creating or rolling a key.
     </description>
   </property>
+
+  <property>
+    <name>hadoop.kms.acl.GENERATE_EEK</name>
+    <value>*</value>
+    <description>
+      ACL for generateEncryptedKey
+      CryptoExtension operations
+    </description>
+  </property>
+
+  <property>
+    <name>hadoop.kms.acl.DECRYPT_EEK</name>
+    <value>*</value>
+    <description>
+      ACL for decrypt EncryptedKey
+      CryptoExtension operations
+    </description>
+  </property>
+</configuration>
 +---+
 
 ** KMS HTTP REST API
@@ -396,6 +415,70 @@ Content-Type: application/json
 }
 +---+
 
+
+*** Generate Encrypted Key for Current KeyVersion
+
+  <REQUEST:>
+
++---+
+GET http://HOST:PORT/kms/v1/key/<key-name>/_eek?eek_op=generate&num_keys=<number-of-keys-to-generate>
++---+
+
+  <RESPONSE:>
+
++---+
+200 OK
+Content-Type: application/json
+[
+  {
+    "versionName"         : "encryptionVersionName",
+    "iv"                  : "<iv>",          //base64
+    "encryptedKeyVersion" : {
+        "versionName"       : "EEK",
+        "material"          : "<material>",    //base64
+    }
+  },
+  {
+    "versionName"         : "encryptionVersionName",
+    "iv"                  : "<iv>",          //base64
+    "encryptedKeyVersion" : {
+        "versionName"       : "EEK",
+        "material"          : "<material>",    //base64
+    }
+  },
+  ...
+]
++---+
+
+*** Decrypt Encrypted Key
+
+  <REQUEST:>
+
++---+
+POST http://HOST:PORT/kms/v1/keyversion/<version-name>/_eek?ee_op=decrypt
+Content-Type: application/json
+
+{
+  "name"        : "<key-name>",
+  "iv"          : "<iv>",          //base64
+  "material"    : "<material>",    //base64
+}
+
++---+
+
+  <RESPONSE:>
+
++---+
+200 OK
+Content-Type: application/json
+
+{
+  "name"        : "EK",
+  "material"    : "<material>",    //base64
+}
++---+
+
+
 *** Get Key Version
 
   <REQUEST:>

+ 100 - 11
hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java

@@ -19,6 +19,9 @@ package org.apache.hadoop.crypto.key.kms.server;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProvider.KeyVersion;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
 import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
 import org.apache.hadoop.minikdc.MiniKdc;
 import org.apache.hadoop.security.authorize.AuthorizationException;
@@ -36,6 +39,7 @@ import javax.security.auth.Subject;
 import javax.security.auth.kerberos.KerberosPrincipal;
 import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.auth.login.LoginContext;
+
 import java.io.File;
 import java.io.FileWriter;
 import java.io.IOException;
@@ -267,7 +271,7 @@ public class TestKMS {
     }
   }
 
-  private void doAs(String user, final PrivilegedExceptionAction<Void> action)
+  private <T> T doAs(String user, final PrivilegedExceptionAction<T> action)
       throws Exception {
     Set<Principal> principals = new HashSet<Principal>();
     principals.add(new KerberosPrincipal(user));
@@ -280,7 +284,7 @@ public class TestKMS {
     try {
       loginContext.login();
       subject = loginContext.getSubject();
-      Subject.doAs(subject, action);
+      return Subject.doAs(subject, action);
     } finally {
       loginContext.logout();
     }
@@ -474,6 +478,32 @@ public class TestKMS {
         Assert.assertNotNull(kms1[0].getCreated());
         Assert.assertTrue(started.before(kms1[0].getCreated()));
 
+        // test generate and decryption of EEK
+        KeyProvider.KeyVersion kv = kp.getCurrentKey("k1");
+        KeyProviderCryptoExtension kpExt =
+            KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp);
+
+        EncryptedKeyVersion ek1 = kpExt.generateEncryptedKey(kv.getName());
+        Assert.assertEquals(KeyProviderCryptoExtension.EEK,
+            ek1.getEncryptedKey().getVersionName());
+        Assert.assertNotNull(ek1.getEncryptedKey().getMaterial());
+        Assert.assertEquals(kv.getMaterial().length,
+            ek1.getEncryptedKey().getMaterial().length);
+        KeyProvider.KeyVersion k1 = kpExt.decryptEncryptedKey(ek1);
+        Assert.assertEquals(KeyProviderCryptoExtension.EK, k1.getVersionName());
+        KeyProvider.KeyVersion k1a = kpExt.decryptEncryptedKey(ek1);
+        Assert.assertArrayEquals(k1.getMaterial(), k1a.getMaterial());
+        Assert.assertEquals(kv.getMaterial().length, k1.getMaterial().length);
+
+        EncryptedKeyVersion ek2 = kpExt.generateEncryptedKey(kv.getName());
+        KeyProvider.KeyVersion k2 = kpExt.decryptEncryptedKey(ek2);
+        boolean isEq = true;
+        for (int i = 0; isEq && i < ek2.getEncryptedKey().getMaterial().length;
+            i++) {
+          isEq = k2.getMaterial()[i] == k1.getMaterial()[i];
+        }
+        Assert.assertFalse(isEq);
+
         // deleteKey()
         kp.deleteKey("k1");
 
@@ -565,7 +595,7 @@ public class TestKMS {
       @Override
       public Void call() throws Exception {
         final Configuration conf = new Configuration();
-        conf.setInt(KeyProvider.DEFAULT_BITLENGTH_NAME, 64);
+        conf.setInt(KeyProvider.DEFAULT_BITLENGTH_NAME, 128);
         URI uri = createKMSUri(getKMSUrl());
         final KeyProvider kp = new KMSClientProvider(uri, conf);
 
@@ -582,7 +612,7 @@ public class TestKMS {
               Assert.fail(ex.toString());
             }
             try {
-              kp.createKey("k", new byte[8], new KeyProvider.Options(conf));
+              kp.createKey("k", new byte[16], new KeyProvider.Options(conf));
               Assert.fail();
             } catch (AuthorizationException ex) {
               //NOP
@@ -598,7 +628,7 @@ public class TestKMS {
               Assert.fail(ex.toString());
             }
             try {
-              kp.rollNewVersion("k", new byte[8]);
+              kp.rollNewVersion("k", new byte[16]);
               Assert.fail();
             } catch (AuthorizationException ex) {
               //NOP
@@ -690,7 +720,7 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KeyProvider.KeyVersion kv = kp.createKey("k1", new byte[8],
+              KeyProvider.KeyVersion kv = kp.createKey("k1", new byte[16],
                   new KeyProvider.Options(conf));
               Assert.assertNull(kv.getMaterial());
             } catch (Exception ex) {
@@ -717,7 +747,8 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KeyProvider.KeyVersion kv = kp.rollNewVersion("k1", new byte[8]);
+              KeyProvider.KeyVersion kv =
+                  kp.rollNewVersion("k1", new byte[16]);
               Assert.assertNull(kv.getMaterial());
             } catch (Exception ex) {
               Assert.fail(ex.toString());
@@ -726,12 +757,46 @@ public class TestKMS {
           }
         });
 
-        doAs("GET", new PrivilegedExceptionAction<Void>() {
+        final KeyVersion currKv =
+            doAs("GET", new PrivilegedExceptionAction<KeyVersion>() {
           @Override
-          public Void run() throws Exception {
+          public KeyVersion run() throws Exception {
             try {
               kp.getKeyVersion("k1@0");
-              kp.getCurrentKey("k1");
+              KeyVersion kv = kp.getCurrentKey("k1");
+              return kv;
+            } catch (Exception ex) {
+              Assert.fail(ex.toString());
+            }
+            return null;
+          }
+        });
+
+        final EncryptedKeyVersion encKv =
+            doAs("GENERATE_EEK",
+                new PrivilegedExceptionAction<EncryptedKeyVersion>() {
+          @Override
+          public EncryptedKeyVersion run() throws Exception {
+            try {
+              KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension.
+                      createKeyProviderCryptoExtension(kp);
+              EncryptedKeyVersion ek1 =
+                  kpCE.generateEncryptedKey(currKv.getName());
+              return ek1;
+            } catch (Exception ex) {
+              Assert.fail(ex.toString());
+            }
+            return null;
+          }
+        });
+
+        doAs("DECRYPT_EEK", new PrivilegedExceptionAction<Void>() {
+          @Override
+          public Void run() throws Exception {
+            try {
+              KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension.
+                      createKeyProviderCryptoExtension(kp);
+              kpCE.decryptEncryptedKey(encKv);
             } catch (Exception ex) {
               Assert.fail(ex.toString());
             }
@@ -817,7 +882,7 @@ public class TestKMS {
       @Override
       public Void call() throws Exception {
         final Configuration conf = new Configuration();
-        conf.setInt(KeyProvider.DEFAULT_BITLENGTH_NAME, 64);
+        conf.setInt(KeyProvider.DEFAULT_BITLENGTH_NAME, 128);
         URI uri = createKMSUri(getKMSUrl());
         final KeyProvider kp = new KMSClientProvider(uri, conf);
 
@@ -889,6 +954,30 @@ public class TestKMS {
       Assert.assertTrue("Caught unexpected exception" + e.toString(), false);
     }
 
+    caughtTimeout = false;
+    try {
+      KeyProvider kp = new KMSClientProvider(uri, conf);
+      KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp)
+          .generateEncryptedKey("a");
+    } catch (SocketTimeoutException e) {
+      caughtTimeout = true;
+    } catch (IOException e) {
+      Assert.assertTrue("Caught unexpected exception" + e.toString(), false);
+    }
+
+    caughtTimeout = false;
+    try {
+      KeyProvider kp = new KMSClientProvider(uri, conf);
+      KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp)
+          .decryptEncryptedKey(
+              new KMSClientProvider.KMSEncryptedKeyVersion("a",
+                  "a", new byte[] {1, 2}, "EEK", new byte[] {1, 2}));
+    } catch (SocketTimeoutException e) {
+      caughtTimeout = true;
+    } catch (IOException e) {
+      Assert.assertTrue("Caught unexpected exception" + e.toString(), false);
+    }
+
     Assert.assertTrue(caughtTimeout);
 
     sock.close();