瀏覽代碼

修复Lambda引发的ClassNotFoundException
* 修复反序列化类加载失败.

聂秋秋 4 年之前
父節點
當前提交
cc6a2059a3

+ 38 - 5
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/ClassUtils.java

@@ -33,6 +33,16 @@ import java.util.List;
  */
  */
 public final class ClassUtils {
 public final class ClassUtils {
 
 
+    private static ClassLoader systemClassLoader;
+
+    static {
+        try {
+            systemClassLoader = ClassLoader.getSystemClassLoader();
+        } catch (SecurityException ignored) {
+            // AccessControlException on Google App Engine
+        }
+    }
+
     private static final char PACKAGE_SEPARATOR = '.';
     private static final char PACKAGE_SEPARATOR = '.';
 
 
     /**
     /**
@@ -146,15 +156,28 @@ public final class ClassUtils {
      * @return 返回转换后的 Class
      * @return 返回转换后的 Class
      */
      */
     public static Class<?> toClassConfident(String name) {
     public static Class<?> toClassConfident(String name) {
+        return toClassConfident(name, null);
+    }
+
+    public static Class<?> toClassConfident(String name, ClassLoader classLoader) {
         try {
         try {
-            return Resources.classForName(name);
+            return loadClass(name, getClassLoaders(classLoader));
         } catch (ClassNotFoundException e) {
         } catch (ClassNotFoundException e) {
-            try {
-                return Class.forName(name);
-            } catch (ClassNotFoundException ex) {
-                throw ExceptionUtils.mpe("找不到指定的class!请仅在明确确定会有 class 的时候,调用该方法", e);
+            throw ExceptionUtils.mpe("找不到指定的class!请仅在明确确定会有 class 的时候,调用该方法", e);
+        }
+    }
+
+    private static Class<?> loadClass(String className, ClassLoader[] classLoaders) throws ClassNotFoundException {
+        for (ClassLoader classLoader : classLoaders) {
+            if (classLoader != null) {
+                try {
+                    return Class.forName(className, true, classLoader);
+                } catch (ClassNotFoundException e) {
+                    // ignore
+                }
             }
             }
         }
         }
+        throw new ClassNotFoundException("Cannot find class: " + className);
     }
     }
 
 
 
 
@@ -201,6 +224,7 @@ public final class ClassUtils {
      * @see ClassLoader#getSystemClassLoader()
      * @see ClassLoader#getSystemClassLoader()
      * @since 3.3.2
      * @since 3.3.2
      */
      */
+    @Deprecated
     public static ClassLoader getDefaultClassLoader() {
     public static ClassLoader getDefaultClassLoader() {
         ClassLoader cl = null;
         ClassLoader cl = null;
         try {
         try {
@@ -222,4 +246,13 @@ public final class ClassUtils {
         }
         }
         return cl;
         return cl;
     }
     }
+
+    private static ClassLoader[] getClassLoaders(ClassLoader classLoader) {
+        return new ClassLoader[]{
+            classLoader,
+            Resources.getDefaultClassLoader(),
+            Thread.currentThread().getContextClassLoader(),
+            ClassUtils.class.getClassLoader(),
+            systemClassLoader};
+    }
 }
 }

+ 3 - 3
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/LambdaUtils.java

@@ -47,13 +47,13 @@ public final class LambdaUtils {
     private static final Map<String, WeakReference<SerializedLambda>> FUNC_CACHE = new ConcurrentHashMap<>();
     private static final Map<String, WeakReference<SerializedLambda>> FUNC_CACHE = new ConcurrentHashMap<>();
 
 
     /**
     /**
-     * 解析 lambda 表达式, 该方法只是调用了 {@link SerializedLambda#resolve(SFunction)} 中的方法,在此基础上加了缓存。
+     * 解析 lambda 表达式, 该方法只是调用了 {@link SerializedLambda#resolve(SFunction, ClassLoader)} 中的方法,在此基础上加了缓存。
      * 该缓存可能会在任意不定的时间被清除
      * 该缓存可能会在任意不定的时间被清除
      *
      *
      * @param func 需要解析的 lambda 对象
      * @param func 需要解析的 lambda 对象
      * @param <T>  类型,被调用的 Function 对象的目标类型
      * @param <T>  类型,被调用的 Function 对象的目标类型
      * @return 返回解析后的结果
      * @return 返回解析后的结果
-     * @see SerializedLambda#resolve(SFunction)
+     * @see SerializedLambda#resolve(SFunction, ClassLoader)
      */
      */
     public static <T> SerializedLambda resolve(SFunction<T, ?> func) {
     public static <T> SerializedLambda resolve(SFunction<T, ?> func) {
         Class<?> clazz = func.getClass();
         Class<?> clazz = func.getClass();
@@ -61,7 +61,7 @@ public final class LambdaUtils {
         return Optional.ofNullable(FUNC_CACHE.get(name))
         return Optional.ofNullable(FUNC_CACHE.get(name))
                 .map(WeakReference::get)
                 .map(WeakReference::get)
                 .orElseGet(() -> {
                 .orElseGet(() -> {
-                    SerializedLambda lambda = SerializedLambda.resolve(func);
+                    SerializedLambda lambda = SerializedLambda.resolve(func, clazz.getClassLoader());
                     FUNC_CACHE.put(name, new WeakReference<>(lambda));
                     FUNC_CACHE.put(name, new WeakReference<>(lambda));
                     return lambda;
                     return lambda;
                 });
                 });

+ 9 - 3
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/support/SerializedLambda.java

@@ -50,8 +50,14 @@ public class SerializedLambda implements Serializable {
      *
      *
      * @param lambda lambda对象
      * @param lambda lambda对象
      * @return 返回解析后的 SerializedLambda
      * @return 返回解析后的 SerializedLambda
+     * @deprecated 3.4.2 {@link #resolve(SFunction, ClassLoader)}
      */
      */
+    @Deprecated
     public static SerializedLambda resolve(SFunction<?, ?> lambda) {
     public static SerializedLambda resolve(SFunction<?, ?> lambda) {
+        return resolve(lambda, null);
+    }
+
+    public static SerializedLambda resolve(SFunction<?, ?> lambda, ClassLoader classLoader) {
         if (!lambda.getClass().isSynthetic()) {
         if (!lambda.getClass().isSynthetic()) {
             throw ExceptionUtils.mpe("该方法仅能传入 lambda 表达式产生的合成类");
             throw ExceptionUtils.mpe("该方法仅能传入 lambda 表达式产生的合成类");
         }
         }
@@ -60,7 +66,7 @@ public class SerializedLambda implements Serializable {
             protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
             protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
                 Class<?> clazz;
                 Class<?> clazz;
                 try {
                 try {
-                    clazz = ClassUtils.toClassConfident(objectStreamClass.getName());
+                    clazz = ClassUtils.toClassConfident(objectStreamClass.getName(), classLoader);
                 } catch (Exception ex) {
                 } catch (Exception ex) {
                     clazz = super.resolveClass(objectStreamClass);
                     clazz = super.resolveClass(objectStreamClass);
                 }
                 }
@@ -88,7 +94,7 @@ public class SerializedLambda implements Serializable {
      * @return 实现类
      * @return 实现类
      */
      */
     public Class<?> getImplClass() {
     public Class<?> getImplClass() {
-        return ClassUtils.toClassConfident(getImplClassName());
+        return ClassUtils.toClassConfident(getImplClassName(), this.capturingClass.getClassLoader());
     }
     }
 
 
     /**
     /**
@@ -124,7 +130,7 @@ public class SerializedLambda implements Serializable {
      */
      */
     public Class<?> getInstantiatedType() {
     public Class<?> getInstantiatedType() {
         String instantiatedTypeName = normalizedName(instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(';')));
         String instantiatedTypeName = normalizedName(instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(';')));
-        return ClassUtils.toClassConfident(instantiatedTypeName);
+        return ClassUtils.toClassConfident(instantiatedTypeName, this.capturingClass.getClassLoader());
     }
     }
 
 
     /**
     /**