Jelajahi Sumber

HDFS-11210. Enhance key rolling to guarantee new KeyVersion is returned from generateEncryptedKeys after a key is rolled.

Xiao Chen 8 tahun lalu
induk
melakukan
2007e0cf2a
16 mengubah file dengan 375 tambahan dan 66 penghapusan
  1. 11 3
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/CachingKeyProvider.java
  2. 12 0
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProvider.java
  3. 5 0
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderExtension.java
  4. 66 1
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyShell.java
  5. 12 1
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
  6. 1 0
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSRESTConstants.java
  7. 19 5
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java
  8. 102 15
      hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/ValueQueue.java
  9. 9 0
      hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestKeyShell.java
  10. 6 0
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/EagerKeyGeneratorKeyProviderCryptoExtension.java
  11. 32 1
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMS.java
  12. 11 0
      hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KeyAuthorizationKeyProvider.java
  13. 13 1
      hadoop-common-project/hadoop-kms/src/site/markdown/index.md.vm
  14. 72 17
      hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
  15. 2 0
      hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMSAudit.java
  16. 2 22
      hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/TestEncryptionZones.java

+ 11 - 3
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/CachingKeyProvider.java

@@ -141,8 +141,7 @@ public class CachingKeyProvider extends
   public KeyVersion rollNewVersion(String name, byte[] material)
       throws IOException {
     KeyVersion key = getKeyProvider().rollNewVersion(name, material);
-    getExtension().currentKeyCache.invalidate(name);
-    getExtension().keyMetadataCache.invalidate(name);
+    invalidateCache(name);
     return key;
   }
 
@@ -150,9 +149,18 @@ public class CachingKeyProvider extends
   public KeyVersion rollNewVersion(String name)
       throws NoSuchAlgorithmException, IOException {
     KeyVersion key = getKeyProvider().rollNewVersion(name);
+    invalidateCache(name);
+    return key;
+  }
+
+  @Override
+  public void invalidateCache(String name) throws IOException {
+    getKeyProvider().invalidateCache(name);
     getExtension().currentKeyCache.invalidate(name);
     getExtension().keyMetadataCache.invalidate(name);
-    return key;
+    // invalidating all key versions as we don't know
+    // which ones belonged to the deleted key
+    getExtension().keyVersionCache.invalidateAll();
   }
 
   @Override

+ 12 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProvider.java

@@ -593,6 +593,18 @@ public abstract class KeyProvider {
     return rollNewVersion(name, material);
   }
 
+  /**
+   * Can be used by implementing classes to invalidate the caches. This could be
+   * used after rollNewVersion to provide a strong guarantee to return the new
+   * version of the given key.
+   *
+   * @param name the basename of the key
+   * @throws IOException
+   */
+  public void invalidateCache(String name) throws IOException {
+    // NOP
+  }
+
   /**
    * Ensures that any changes to the keys are written to persistent store.
    * @throws IOException

+ 5 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderExtension.java

@@ -117,6 +117,11 @@ public abstract class KeyProviderExtension
     return keyProvider.rollNewVersion(name, material);
   }
 
+  @Override
+  public void invalidateCache(String name) throws IOException {
+    keyProvider.invalidateCache(name);
+  }
+
   @Override
   public void flush() throws IOException {
     keyProvider.flush();

+ 66 - 1
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyShell.java

@@ -46,7 +46,8 @@ public class KeyShell extends CommandShell {
       "   [" + CreateCommand.USAGE + "]\n" +
       "   [" + RollCommand.USAGE + "]\n" +
       "   [" + DeleteCommand.USAGE + "]\n" +
-      "   [" + ListCommand.USAGE + "]\n";
+      "   [" + ListCommand.USAGE + "]\n" +
+      "   [" + InvalidateCacheCommand.USAGE + "]\n";
   private static final String LIST_METADATA = "keyShell.list.metadata";
   @VisibleForTesting
   public static final String NO_VALID_PROVIDERS =
@@ -70,6 +71,7 @@ public class KeyShell extends CommandShell {
    * % hadoop key roll keyName [-provider providerPath]
    * % hadoop key list [-provider providerPath]
    * % hadoop key delete keyName [-provider providerPath] [-i]
+   * % hadoop key invalidateCache keyName [-provider providerPath]
    * </pre>
    * @param args Command line arguments.
    * @return 0 on success, 1 on failure.
@@ -111,6 +113,15 @@ public class KeyShell extends CommandShell {
         }
       } else if ("list".equals(args[i])) {
         setSubCommand(new ListCommand());
+      } else if ("invalidateCache".equals(args[i])) {
+        String keyName = "-help";
+        if (moreTokens) {
+          keyName = args[++i];
+        }
+        setSubCommand(new InvalidateCacheCommand(keyName));
+        if ("-help".equals(keyName)) {
+          return 1;
+        }
       } else if ("-size".equals(args[i]) && moreTokens) {
         options.setBitLength(Integer.parseInt(args[++i]));
       } else if ("-cipher".equals(args[i]) && moreTokens) {
@@ -168,6 +179,9 @@ public class KeyShell extends CommandShell {
     sbuf.append(DeleteCommand.USAGE + ":\n\n" + DeleteCommand.DESC + "\n");
     sbuf.append(banner + "\n");
     sbuf.append(ListCommand.USAGE + ":\n\n" + ListCommand.DESC + "\n");
+    sbuf.append(banner + "\n");
+    sbuf.append(InvalidateCacheCommand.USAGE + ":\n\n"
+        + InvalidateCacheCommand.DESC + "\n");
     return sbuf.toString();
   }
 
@@ -466,6 +480,57 @@ public class KeyShell extends CommandShell {
     }
   }
 
+  private class InvalidateCacheCommand extends Command {
+    public static final String USAGE =
+        "invalidateCache <keyname> [-provider <provider>] [-help]";
+    public static final String DESC =
+        "The invalidateCache subcommand invalidates the cached key versions\n"
+            + "of the specified key, on the provider indicated using the"
+            + " -provider argument.\n";
+
+    private String keyName = null;
+
+    InvalidateCacheCommand(String keyName) {
+      this.keyName = keyName;
+    }
+
+    public boolean validate() {
+      boolean rc = true;
+      provider = getKeyProvider();
+      if (provider == null) {
+        getOut().println("Invalid provider.");
+        rc = false;
+      }
+      if (keyName == null) {
+        getOut().println("Please provide a <keyname>.\n" +
+            "See the usage description by using -help.");
+        rc = false;
+      }
+      return rc;
+    }
+
+    public void execute() throws NoSuchAlgorithmException, IOException {
+      try {
+        warnIfTransientProvider();
+        getOut().println("Invalidating cache on KeyProvider: "
+            + provider + "\n  for key name: " + keyName);
+        provider.invalidateCache(keyName);
+        getOut().println("Cached keyversions of " + keyName
+            + " has been successfully invalidated.");
+        printProviderWritten();
+      } catch (IOException e) {
+        getOut().println("Cannot invalidate cache for key: " + keyName +
+            " within KeyProvider: " + provider + ". " + e.toString());
+        throw e;
+      }
+    }
+
+    @Override
+    public String getUsage() {
+      return USAGE + ":\n\n" + DESC;
+    }
+  }
+
   /**
    * main() entry point for the KeyShell.  While strictly speaking the
    * return is void, it will System.exit() with a return code: 0 is for

+ 12 - 1
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java

@@ -757,6 +757,17 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension,
     }
   }
 
+  @Override
+  public void invalidateCache(String name) throws IOException {
+    checkNotEmpty(name, "name");
+    final URL url = createURL(KMSRESTConstants.KEY_RESOURCE, name,
+        KMSRESTConstants.INVALIDATECACHE_RESOURCE, null);
+    final HttpURLConnection conn = createConnection(url, HTTP_POST);
+    // invalidate the server cache first, then drain local cache.
+    call(conn, null, HttpURLConnection.HTTP_OK, null);
+    drain(name);
+  }
+
   private KeyVersion rollNewVersionInternal(String name, byte[] material)
       throws NoSuchAlgorithmException, IOException {
     checkNotEmpty(name, "name");
@@ -771,7 +782,7 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension,
     Map response = call(conn, jsonMaterial,
         HttpURLConnection.HTTP_OK, Map.class);
     KeyVersion keyVersion = parseJSONKeyVersion(response);
-    encKeyVersionQueue.drain(name);
+    invalidateCache(name);
     return keyVersion;
   }
 

+ 1 - 0
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSRESTConstants.java

@@ -36,6 +36,7 @@ public class KMSRESTConstants {
   public static final String VERSIONS_SUB_RESOURCE = "_versions";
   public static final String EEK_SUB_RESOURCE = "_eek";
   public static final String CURRENT_VERSION_SUB_RESOURCE = "_currentversion";
+  public static final String INVALIDATECACHE_RESOURCE = "_invalidatecache";
 
   public static final String KEY = "key";
   public static final String EEK_OP = "eek_op";

+ 19 - 5
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java

@@ -178,6 +178,14 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements
     }
   }
 
+  // This request is sent to all providers in the load-balancing group
+  @Override
+  public void invalidateCache(String keyName) throws IOException {
+    for (KMSClientProvider provider : providers) {
+      provider.invalidateCache(keyName);
+    }
+  }
+
   @Override
   public EncryptedKeyVersion
       generateEncryptedKey(final String encryptionKeyName)
@@ -218,14 +226,14 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements
     }
   }
 
-  public EncryptedKeyVersion reencryptEncryptedKey(EncryptedKeyVersion edek)
+  public EncryptedKeyVersion reencryptEncryptedKey(EncryptedKeyVersion ekv)
       throws IOException, GeneralSecurityException {
     try {
       return doOp(new ProviderCallable<EncryptedKeyVersion>() {
         @Override
         public EncryptedKeyVersion call(KMSClientProvider provider)
             throws IOException, GeneralSecurityException {
-          return provider.reencryptEncryptedKey(edek);
+          return provider.reencryptEncryptedKey(ekv);
         }
       }, nextIdx());
     } catch (WrapperException we) {
@@ -325,6 +333,7 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements
       throw new IOException(e.getCause());
     }
   }
+
   @Override
   public void deleteKey(final String name) throws IOException {
     doOp(new ProviderCallable<Void>() {
@@ -335,28 +344,33 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements
       }
     }, nextIdx());
   }
+
   @Override
   public KeyVersion rollNewVersion(final String name, final byte[] material)
       throws IOException {
-    return doOp(new ProviderCallable<KeyVersion>() {
+    final KeyVersion newVersion = doOp(new ProviderCallable<KeyVersion>() {
       @Override
       public KeyVersion call(KMSClientProvider provider) throws IOException {
         return provider.rollNewVersion(name, material);
       }
     }, nextIdx());
+    invalidateCache(name);
+    return newVersion;
   }
 
   @Override
   public KeyVersion rollNewVersion(final String name)
       throws NoSuchAlgorithmException, IOException {
     try {
-      return doOp(new ProviderCallable<KeyVersion>() {
+      final KeyVersion newVersion = doOp(new ProviderCallable<KeyVersion>() {
         @Override
         public KeyVersion call(KMSClientProvider provider) throws IOException,
-        NoSuchAlgorithmException {
+            NoSuchAlgorithmException {
           return provider.rollNewVersion(name);
         }
       }, nextIdx());
+      invalidateCache(name);
+      return newVersion;
     } catch (WrapperException e) {
       if (e.getCause() instanceof GeneralSecurityException) {
         throw (NoSuchAlgorithmException) e.getCause();

+ 102 - 15
hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/ValueQueue.java

@@ -18,8 +18,9 @@
 package org.apache.hadoop.crypto.key.kms;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.HashSet;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -28,6 +29,9 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.base.Preconditions;
 import com.google.common.cache.CacheBuilder;
@@ -67,8 +71,17 @@ public class ValueQueue <E> {
 
   private static final String REFILL_THREAD =
       ValueQueue.class.getName() + "_thread";
+  private static final int LOCK_ARRAY_SIZE = 16;
+  // Using a mask assuming array size is the power of 2, of MAX_VALUE.
+  private static final int MASK = LOCK_ARRAY_SIZE == Integer.MAX_VALUE ?
+      LOCK_ARRAY_SIZE :
+      LOCK_ARRAY_SIZE - 1;
 
   private final LoadingCache<String, LinkedBlockingQueue<E>> keyQueues;
+  // Stripped rwlocks based on key name to synchronize the queue from
+  // the sync'ed rw-thread and the background async refill thread.
+  private final List<ReadWriteLock> lockArray =
+      new ArrayList<>(LOCK_ARRAY_SIZE);
   private final ThreadPoolExecutor executor;
   private final UniqueKeyBlockingQueue queue = new UniqueKeyBlockingQueue();
   private final QueueRefiller<E> refiller;
@@ -84,9 +97,47 @@ public class ValueQueue <E> {
    */
   private abstract static class NamedRunnable implements Runnable {
     final String name;
+    private AtomicBoolean canceled = new AtomicBoolean(false);
     private NamedRunnable(String keyName) {
       this.name = keyName;
     }
+
+    public void cancel() {
+      canceled.set(true);
+    }
+
+    public boolean isCanceled() {
+      return canceled.get();
+    }
+  }
+
+  private void readLock(String keyName) {
+    getLock(keyName).readLock().lock();
+  }
+
+  private void readUnlock(String keyName) {
+    getLock(keyName).readLock().unlock();
+  }
+
+  private void writeUnlock(String keyName) {
+    getLock(keyName).writeLock().unlock();
+  }
+
+  private void writeLock(String keyName) {
+    getLock(keyName).writeLock().lock();
+  }
+
+  /**
+   * Get the stripped lock given a key name.
+   *
+   * @param keyName The key name.
+   */
+  private ReadWriteLock getLock(String keyName) {
+    return lockArray.get(indexFor(keyName));
+  }
+
+  private static int indexFor(String keyName) {
+    return keyName.hashCode() & MASK;
   }
 
   /**
@@ -103,11 +154,12 @@ public class ValueQueue <E> {
       LinkedBlockingQueue<Runnable> {
 
     private static final long serialVersionUID = -2152747693695890371L;
-    private HashSet<String> keysInProgress = new HashSet<String>();
+    private HashMap<String, Runnable> keysInProgress = new HashMap<>();
 
     @Override
     public synchronized void put(Runnable e) throws InterruptedException {
-      if (keysInProgress.add(((NamedRunnable)e).name)) {
+      if (!keysInProgress.containsKey(((NamedRunnable)e).name)) {
+        keysInProgress.put(((NamedRunnable)e).name, e);
         super.put(e);
       }
     }
@@ -131,6 +183,14 @@ public class ValueQueue <E> {
       return k;
     }
 
+    public Runnable deleteByName(String name) {
+      NamedRunnable e = (NamedRunnable) keysInProgress.remove(name);
+      if (e != null) {
+        e.cancel();
+        super.remove(e);
+      }
+      return e;
+    }
   }
 
   /**
@@ -172,6 +232,9 @@ public class ValueQueue <E> {
     this.policy = policy;
     this.numValues = numValues;
     this.lowWatermark = lowWatermark;
+    for (int i = 0; i < LOCK_ARRAY_SIZE; ++i) {
+      lockArray.add(i, new ReentrantReadWriteLock());
+    }
     keyQueues = CacheBuilder.newBuilder()
             .expireAfterAccess(expiry, TimeUnit.MILLISECONDS)
             .build(new CacheLoader<String, LinkedBlockingQueue<E>>() {
@@ -233,9 +296,18 @@ public class ValueQueue <E> {
    *
    * @param keyName the key to drain the Queue for
    */
-  public void drain(String keyName ) {
+  public void drain(String keyName) {
     try {
-      keyQueues.get(keyName).clear();
+      Runnable e;
+      while ((e = queue.deleteByName(keyName)) != null) {
+        executor.remove(e);
+      }
+      writeLock(keyName);
+      try {
+        keyQueues.get(keyName).clear();
+      } finally {
+        writeUnlock(keyName);
+      }
     } catch (ExecutionException ex) {
       //NOP
     }
@@ -247,14 +319,19 @@ public class ValueQueue <E> {
    * @return int queue size
    */
   public int getSize(String keyName) {
-    // We can't do keyQueues.get(keyName).size() here,
-    // since that will have the side effect of populating the cache.
-    Map<String, LinkedBlockingQueue<E>> map =
-        keyQueues.getAllPresent(Arrays.asList(keyName));
-    if (map.get(keyName) == null) {
-      return 0;
+    readLock(keyName);
+    try {
+      // We can't do keyQueues.get(keyName).size() here,
+      // since that will have the side effect of populating the cache.
+      Map<String, LinkedBlockingQueue<E>> map =
+          keyQueues.getAllPresent(Arrays.asList(keyName));
+      if (map.get(keyName) == null) {
+        return 0;
+      }
+      return map.get(keyName).size();
+    } finally {
+      readUnlock(keyName);
     }
-    return map.get(keyName).size();
   }
 
   /**
@@ -276,7 +353,9 @@ public class ValueQueue <E> {
     LinkedList<E> ekvs = new LinkedList<E>();
     try {
       for (int i = 0; i < num; i++) {
+        readLock(keyName);
         E val = keyQueue.poll();
+        readUnlock(keyName);
         // If queue is empty now, Based on the provided SyncGenerationPolicy,
         // figure out how many new values need to be generated synchronously
         if (val == null) {
@@ -336,9 +415,17 @@ public class ValueQueue <E> {
             int threshold = (int) (lowWatermark * (float) cacheSize);
             // Need to ensure that only one refill task per key is executed
             try {
-              if (keyQueue.size() < threshold) {
-                refiller.fillQueueForKey(name, keyQueue,
-                    cacheSize - keyQueue.size());
+              writeLock(keyName);
+              try {
+                if (keyQueue.size() < threshold && !isCanceled()) {
+                  refiller.fillQueueForKey(name, keyQueue,
+                      cacheSize - keyQueue.size());
+                }
+                if (isCanceled()) {
+                  keyQueue.clear();
+                }
+              } finally {
+                writeUnlock(keyName);
               }
             } catch (final Exception e) {
               throw new RuntimeException(e);

+ 9 - 0
hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestKeyShell.java

@@ -138,6 +138,15 @@ public class TestKeyShell {
     assertTrue(outContent.toString().contains("key1 has been successfully " +
         "rolled."));
 
+    // jceks provider's invalidate is a no-op.
+    outContent.reset();
+    final String[] args3 =
+        {"invalidateCache", keyName, "-provider", jceksProvider};
+    rc = ks.run(args3);
+    assertEquals(0, rc);
+    assertTrue(outContent.toString()
+        .contains("key1 has been successfully " + "invalidated."));
+
     deleteKey(ks, keyName);
 
     listOut = listKeys(ks, false);

+ 6 - 0
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/EagerKeyGeneratorKeyProviderCryptoExtension.java

@@ -183,4 +183,10 @@ public class EagerKeyGeneratorKeyProviderCryptoExtension
     getExtension().drain(name);
     return keyVersion;
   }
+
+  @Override
+  public void invalidateCache(String name) throws IOException {
+    super.invalidateCache(name);
+    getExtension().drain(name);
+  }
 }

+ 32 - 1
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMS.java

@@ -61,7 +61,7 @@ import java.util.Map;
 public class KMS {
 
   public static enum KMSOp {
-    CREATE_KEY, DELETE_KEY, ROLL_NEW_VERSION,
+    CREATE_KEY, DELETE_KEY, ROLL_NEW_VERSION, INVALIDATE_CACHE,
     GET_KEYS, GET_KEYS_METADATA,
     GET_KEY_VERSIONS, GET_METADATA, GET_KEY_VERSION, GET_CURRENT_KEY,
     GENERATE_EEK, DECRYPT_EEK, REENCRYPT_EEK
@@ -252,6 +252,37 @@ public class KMS {
     }
   }
 
+  @POST
+  @Path(KMSRESTConstants.KEY_RESOURCE + "/{name:.*}/"
+      + KMSRESTConstants.INVALIDATECACHE_RESOURCE)
+  public Response invalidateCache(@PathParam("name") final String name)
+      throws Exception {
+    try {
+      LOG.trace("Entering invalidateCache Method.");
+      KMSWebApp.getAdminCallsMeter().mark();
+      KMSClientProvider.checkNotEmpty(name, "name");
+      UserGroupInformation user = HttpUserGroupInformation.get();
+      assertAccess(KMSACLs.Type.ROLLOVER, user, KMSOp.INVALIDATE_CACHE, name);
+      LOG.debug("Invalidating cache with key name {}.", name);
+
+      user.doAs(new PrivilegedExceptionAction<Void>() {
+        @Override
+        public Void run() throws Exception {
+          provider.invalidateCache(name);
+          provider.flush();
+          return null;
+        }
+      });
+
+      kmsAudit.ok(user, KMSOp.INVALIDATE_CACHE, name, "");
+      LOG.trace("Exiting invalidateCache for key name {}.", name);
+      return Response.ok().build();
+    } catch (Exception e) {
+      LOG.debug("Exception in invalidateCache for key name {}.", name, e);
+      throw e;
+    }
+  }
+
   @GET
   @Path(KMSRESTConstants.KEYS_METADATA_RESOURCE)
   @Produces(MediaType.APPLICATION_JSON + "; " + JettyUtils.UTF_8)

+ 11 - 0
hadoop-common-project/hadoop-kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KeyAuthorizationKeyProvider.java

@@ -210,6 +210,17 @@ public class KeyAuthorizationKeyProvider extends KeyProviderCryptoExtension {
     }
   }
 
+  @Override
+  public void invalidateCache(String name) throws IOException {
+    writeLock.lock();
+    try {
+      doAccessCheck(name, KeyOpType.MANAGEMENT);
+      provider.invalidateCache(name);
+    } finally {
+      writeLock.unlock();
+    }
+  }
+
   @Override
   public void warmUpEncryptedKeys(String... names) throws IOException {
     readLock.lock();

+ 13 - 1
hadoop-common-project/hadoop-kms/src/site/markdown/index.md.vm

@@ -103,7 +103,9 @@ This cache is used with the following 3 methods only, `getCurrentKey()` and `get
 
 For the `getCurrentKey()` method, cached entries are kept for a maximum of 30000 milliseconds regardless the number of times the key is being accessed (to avoid stale keys to be considered current).
 
-For the `getKeyVersion()` method, cached entries are kept with a default inactivity timeout of 600000 milliseconds (10 mins).
+For the `getKeyVersion()` and `getMetadata()` methods, cached entries are kept with a default inactivity timeout of 600000 milliseconds (10 mins).
+
+The cache is invalidated when the key is deleted by `deleteKey()`, or when `invalidateCache()` is called.
 
 These configurations can be changed via the following properties in the `etc/hadoop/kms-site.xml` configuration file:
 
@@ -841,6 +843,16 @@ $H4 Rollover Key
       "material"    : "<material>",    //base64, not present without GET ACL
     }
 
+$H4 Invalidate Cache of a Key
+
+*REQUEST:*
+
+    POST http://HOST:PORT/kms/v1/key/<key-name>/_invalidatecache
+
+*RESPONSE:*
+
+    200 OK
+
 $H4 Delete Key
 
 *REQUEST:*

+ 72 - 17
hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java

@@ -18,6 +18,7 @@
 package org.apache.hadoop.crypto.key.kms.server;
 
 import com.google.common.base.Supplier;
+import com.google.common.cache.LoadingCache;
 import org.apache.curator.test.TestingServer;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.crypto.key.KeyProviderFactory;
@@ -31,7 +32,7 @@ import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
 import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
 import org.apache.hadoop.crypto.key.kms.KMSDelegationToken;
 import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider;
-import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
+import org.apache.hadoop.crypto.key.kms.ValueQueue;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.minikdc.MiniKdc;
 import org.apache.hadoop.security.Credentials;
@@ -49,6 +50,8 @@ import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.Timeout;
+import org.mockito.Mockito;
+import org.mockito.internal.util.reflection.Whitebox;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -79,11 +82,14 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.Callable;
+import java.util.concurrent.LinkedBlockingQueue;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.Mockito.when;
 
 public class TestKMS {
   private static final Logger LOG = LoggerFactory.getLogger(TestKMS.class);
@@ -128,6 +134,11 @@ public class TestKMS {
         new KMSClientProvider[] { new KMSClientProvider(uri, conf) }, conf);
   }
 
+  private KMSClientProvider createKMSClientProvider(URI uri, Configuration conf)
+      throws IOException {
+    return new KMSClientProvider(uri, conf);
+  }
+
   protected <T> T runServer(String keystore, String password, File confDir,
       KMSCallable<T> callable) throws Exception {
     return runServer(-1, keystore, password, confDir, callable);
@@ -723,24 +734,68 @@ public class TestKMS {
 
         EncryptedKeyVersion ekv1 = kpce.generateEncryptedKey("k6");
         kpce.rollNewVersion("k6");
+        kpce.invalidateCache("k6");
+        EncryptedKeyVersion ekv2 = kpce.generateEncryptedKey("k6");
+        assertNotEquals("rollover did not generate a new key even after"
+            + " queue is drained", ekv1.getEncryptionKeyVersionName(),
+            ekv2.getEncryptionKeyVersionName());
+        return null;
+      }
+    });
+  }
 
-        /**
-         * due to the cache on the server side, client may get old keys.
-         * @see EagerKeyGeneratorKeyProviderCryptoExtension#rollNewVersion(String)
-         */
-        boolean rollSucceeded = false;
-        for (int i = 0; i <= EagerKeyGeneratorKeyProviderCryptoExtension
-            .KMS_KEY_CACHE_SIZE_DEFAULT + CommonConfigurationKeysPublic.
-            KMS_CLIENT_ENC_KEY_CACHE_SIZE_DEFAULT; ++i) {
-          EncryptedKeyVersion ekv2 = kpce.generateEncryptedKey("k6");
-          if (!(ekv1.getEncryptionKeyVersionName()
-              .equals(ekv2.getEncryptionKeyVersionName()))) {
-            rollSucceeded = true;
-            break;
-          }
+  @Test
+  public void testKMSProviderCaching() throws Exception {
+    Configuration conf = new Configuration();
+    File confDir = getTestDir();
+    conf = createBaseKMSConf(confDir, conf);
+    conf.set(KeyAuthorizationKeyProvider.KEY_ACL + "k1.ALL", "*");
+    writeConf(confDir, conf);
+
+    runServer(null, null, confDir, new KMSCallable<Void>() {
+      @Override
+      public Void call() throws Exception {
+        final String keyName = "k1";
+        final String mockVersionName = "mock";
+        final Configuration conf = new Configuration();
+        final URI uri = createKMSUri(getKMSUrl());
+        KMSClientProvider kmscp = createKMSClientProvider(uri, conf);
+
+        // get the reference to the internal cache, to test invalidation.
+        ValueQueue vq =
+            (ValueQueue) Whitebox.getInternalState(kmscp, "encKeyVersionQueue");
+        LoadingCache<String, LinkedBlockingQueue<EncryptedKeyVersion>> kq =
+            ((LoadingCache<String, LinkedBlockingQueue<EncryptedKeyVersion>>)
+                Whitebox.getInternalState(vq, "keyQueues"));
+        EncryptedKeyVersion mockEKV = Mockito.mock(EncryptedKeyVersion.class);
+        when(mockEKV.getEncryptionKeyName()).thenReturn(keyName);
+        when(mockEKV.getEncryptionKeyVersionName()).thenReturn(mockVersionName);
+
+        // createKey()
+        KeyProvider.Options options = new KeyProvider.Options(conf);
+        options.setCipher("AES/CTR/NoPadding");
+        options.setBitLength(128);
+        options.setDescription("l1");
+        KeyProvider.KeyVersion kv0 = kmscp.createKey(keyName, options);
+        assertNotNull(kv0.getVersionName());
+
+        assertEquals("Default key version name is incorrect.", "k1@0",
+            kmscp.generateEncryptedKey(keyName).getEncryptionKeyVersionName());
+
+        kmscp.invalidateCache(keyName);
+        kq.get(keyName).put(mockEKV);
+        assertEquals("Key version incorrect after invalidating cache + putting"
+                + " mock key.", mockVersionName,
+            kmscp.generateEncryptedKey(keyName).getEncryptionKeyVersionName());
+
+        // test new version is returned after invalidation.
+        for (int i = 0; i < 100; ++i) {
+          kq.get(keyName).put(mockEKV);
+          kmscp.invalidateCache(keyName);
+          assertEquals("Cache invalidation guarantee failed.", "k1@0",
+              kmscp.generateEncryptedKey(keyName)
+                  .getEncryptionKeyVersionName());
         }
-        Assert.assertTrue("rollover did not generate a new key even after"
-            + " queue is drained", rollSucceeded);
         return null;
       }
     });

+ 2 - 0
hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMSAudit.java

@@ -104,6 +104,7 @@ public class TestKMSAudit {
     kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
     kmsAudit.ok(luser, KMSOp.DELETE_KEY, "k1", "testmsg");
     kmsAudit.ok(luser, KMSOp.ROLL_NEW_VERSION, "k1", "testmsg");
+    kmsAudit.ok(luser, KMSOp.INVALIDATE_CACHE, "k1", "testmsg");
     kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
     kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
     kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
@@ -122,6 +123,7 @@ public class TestKMSAudit {
             // Not aggregated !!
             + "OK\\[op=DELETE_KEY, key=k1, user=luser\\] testmsg"
             + "OK\\[op=ROLL_NEW_VERSION, key=k1, user=luser\\] testmsg"
+            + "OK\\[op=INVALIDATE_CACHE, key=k1, user=luser\\] testmsg"
             // Aggregated
             + "OK\\[op=DECRYPT_EEK, key=k1, user=luser, accessCount=6, interval=[^m]{1,4}ms\\] testmsg"
             + "OK\\[op=DECRYPT_EEK, key=k1, user=luser, accessCount=1, interval=[^m]{1,4}ms\\] testmsg"

+ 2 - 22
hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/TestEncryptionZones.java

@@ -44,9 +44,7 @@ import org.apache.hadoop.crypto.CipherSuite;
 import org.apache.hadoop.crypto.CryptoProtocolVersion;
 import org.apache.hadoop.crypto.key.JavaKeyStoreProvider;
 import org.apache.hadoop.crypto.key.KeyProvider;
-import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
 import org.apache.hadoop.crypto.key.KeyProviderFactory;
-import org.apache.hadoop.crypto.key.kms.server.EagerKeyGeneratorKeyProviderCryptoExtension;
 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
 import org.apache.hadoop.fs.CreateFlag;
 import org.apache.hadoop.fs.FSDataOutputStream;
@@ -730,33 +728,15 @@ public class TestEncryptionZones {
     // Roll the key of the encryption zone
     assertNumZones(1);
     String keyName = dfsAdmin.listEncryptionZones().next().getKeyName();
-    FileEncryptionInfo feInfo1 = getFileEncryptionInfo(encFile1);
     cluster.getNamesystem().getProvider().rollNewVersion(keyName);
-    /**
-     * due to the cache on the server side, client may get old keys.
-     * @see EagerKeyGeneratorKeyProviderCryptoExtension#rollNewVersion(String)
-     */
-    boolean rollSucceeded = false;
-    for (int i = 0; i <= EagerKeyGeneratorKeyProviderCryptoExtension
-        .KMS_KEY_CACHE_SIZE_DEFAULT + CommonConfigurationKeysPublic.
-        KMS_CLIENT_ENC_KEY_CACHE_SIZE_DEFAULT; ++i) {
-      KeyProviderCryptoExtension.EncryptedKeyVersion ekv2 =
-          cluster.getNamesystem().getProvider().generateEncryptedKey(TEST_KEY);
-      if (!(feInfo1.getEzKeyVersionName()
-          .equals(ekv2.getEncryptionKeyVersionName()))) {
-        rollSucceeded = true;
-        break;
-      }
-    }
-    Assert.assertTrue("rollover did not generate a new key even after"
-        + " queue is drained", rollSucceeded);
-
+    cluster.getNamesystem().getProvider().invalidateCache(keyName);
     // Read them back in and compare byte-by-byte
     verifyFilesEqual(fs, baseFile, encFile1, len);
     // Write a new enc file and validate
     final Path encFile2 = new Path(zone, "myfile2");
     DFSTestUtil.createFile(fs, encFile2, len, (short) 1, 0xFEED);
     // FEInfos should be different
+    FileEncryptionInfo feInfo1 = getFileEncryptionInfo(encFile1);
     FileEncryptionInfo feInfo2 = getFileEncryptionInfo(encFile2);
     assertFalse("EDEKs should be different", Arrays
         .equals(feInfo1.getEncryptedDataEncryptionKey(),