Browse Source

pref: 获取 lambda 信息不再使用序列化 https://github.com/baomidou/mybatis-plus/pull/3517

hanchunlin 4 years ago
parent
commit
3b1741e40e

+ 6 - 10
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractLambdaWrapper.java

@@ -20,9 +20,9 @@ import com.baomidou.mybatisplus.core.toolkit.LambdaUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
-import com.baomidou.mybatisplus.core.toolkit.support.SerializedLambda;
 import org.apache.ibatis.reflection.property.PropertyNamer;
 
+import java.lang.invoke.SerializedLambda;
 import java.util.Arrays;
 import java.util.Map;
 
@@ -37,7 +37,7 @@ import static java.util.stream.Collectors.joining;
  */
 @SuppressWarnings("serial")
 public abstract class AbstractLambdaWrapper<T, Children extends AbstractLambdaWrapper<T, Children>>
-    extends AbstractWrapper<T, SFunction<T, ?>, Children> {
+        extends AbstractWrapper<T, SFunction<T, ?>, Children> {
 
     private Map<String, ColumnCache> columnMap = null;
     private boolean initColumnMap = false;
@@ -70,15 +70,11 @@ public abstract class AbstractLambdaWrapper<T, Children extends AbstractLambdaWr
      *
      * @return 列
      * @throws com.baomidou.mybatisplus.core.exceptions.MybatisPlusException 获取不到列信息时抛出异常
-     * @see SerializedLambda#getImplClass()
-     * @see SerializedLambda#getImplMethodName()
      */
     protected ColumnCache getColumnCache(SFunction<T, ?> column) {
-        SerializedLambda lambda = LambdaUtils.resolve(column);
-        Class<?> aClass = lambda.getInstantiatedType();
-        tryInitCache(aClass);
-        String fieldName = PropertyNamer.methodToProperty(lambda.getImplMethodName());
-        return getColumnCache(fieldName, aClass);
+        SerializedLambda lambda = LambdaUtils.extract(column);
+        String fileName = PropertyNamer.methodToProperty(lambda.getImplMethodName());
+        return getColumnCache(fileName, LambdaUtils.instantiatedClass(lambda));
     }
 
     private void tryInitCache(Class<?> lambdaClass) {
@@ -96,7 +92,7 @@ public abstract class AbstractLambdaWrapper<T, Children extends AbstractLambdaWr
     private ColumnCache getColumnCache(String fieldName, Class<?> lambdaClass) {
         ColumnCache columnCache = columnMap.get(LambdaUtils.formatKey(fieldName));
         Assert.notNull(columnCache, "can not find lambda cache for this property [%s] of entity [%s]",
-            fieldName, lambdaClass.getName());
+                fieldName, lambdaClass.getName());
         return columnCache;
     }
 }

+ 52 - 21
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/LambdaUtils.java

@@ -15,15 +15,17 @@
  */
 package com.baomidou.mybatisplus.core.toolkit;
 
+import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
-import com.baomidou.mybatisplus.core.toolkit.support.SerializedLambda;
 
-import java.lang.ref.WeakReference;
+import java.lang.invoke.SerializedLambda;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
 import java.util.Map;
-import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 
 import static java.util.Locale.ENGLISH;
@@ -35,6 +37,16 @@ import static java.util.Locale.ENGLISH;
  * @since 2018-05-10
  */
 public final class LambdaUtils {
+    private static final Field FIELD_CAPTURING_CLASS;
+
+    static {
+        try {
+            Class<SerializedLambda> aClass = SerializedLambda.class;
+            FIELD_CAPTURING_CLASS = ReflectionKit.setAccessible(aClass.getDeclaredField("capturingClass"));
+        } catch (NoSuchFieldException e) {
+            throw new MybatisPlusException(e);
+        }
+    }
 
     /**
      * 字段映射
@@ -42,29 +54,48 @@ public final class LambdaUtils {
     private static final Map<String, Map<String, ColumnCache>> COLUMN_CACHE_MAP = new ConcurrentHashMap<>();
 
     /**
-     * SerializedLambda 反序列化缓存
-     */
-    private static final Map<String, WeakReference<SerializedLambda>> FUNC_CACHE = new ConcurrentHashMap<>();
-
-    /**
-     * 解析 lambda 表达式, 该方法只是调用了 {@link SerializedLambda#resolve(SFunction, ClassLoader)} 中的方法,在此基础上加了缓存。
      * 该缓存可能会在任意不定的时间被清除
      *
      * @param func 需要解析的 lambda 对象
      * @param <T>  类型,被调用的 Function 对象的目标类型
      * @return 返回解析后的结果
-     * @see SerializedLambda#resolve(SFunction, ClassLoader)
      */
-    public static <T> SerializedLambda resolve(SFunction<T, ?> func) {
-        Class<?> clazz = func.getClass();
-        String name = clazz.getName();
-        return Optional.ofNullable(FUNC_CACHE.get(name))
-            .map(WeakReference::get)
-            .orElseGet(() -> {
-                SerializedLambda lambda = SerializedLambda.resolve(func, clazz.getClassLoader());
-                FUNC_CACHE.put(name, new WeakReference<>(lambda));
-                return lambda;
-            });
+    public static <T> SerializedLambda extract(SFunction<T, ?> func) {
+        try {
+            Method method = func.getClass().getDeclaredMethod("writeReplace");
+            return (SerializedLambda) ReflectionKit.setAccessible(method).invoke(func);
+        } catch (NoSuchMethodException e) {
+            String message = "Cannot find method writeReplace, please make sure that the lambda composite class is currently passed in";
+            throw new MybatisPlusException(message);
+        } catch (InvocationTargetException | IllegalAccessException e) {
+            throw new MybatisPlusException(e);
+        }
+    }
+
+    /**
+     * 实例化该接口的类名
+     *
+     * @param lambda lambda 对象
+     * @return 返回对应的实例类
+     */
+    public static Class<?> instantiatedClass(SerializedLambda lambda) {
+        String instantiatedMethodType = lambda.getInstantiatedMethodType();
+        String instantiatedType = instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(';')).replace('/', '.');
+        return ClassUtils.toClassConfident(instantiatedType, capturingClass(lambda).getClassLoader());
+    }
+
+    /**
+     * 获取 lambda 的捕获类,这取决于 lambda 类在构造时所处的类
+     *
+     * @param lambda lambda
+     * @return 返回对应的捕获类
+     */
+    public static Class<?> capturingClass(SerializedLambda lambda) {
+        try {
+            return (Class<?>) FIELD_CAPTURING_CLASS.get(lambda);
+        } catch (IllegalAccessException e) {
+            throw new MybatisPlusException(e);
+        }
     }
 
     /**
@@ -107,7 +138,7 @@ public final class LambdaUtils {
         }
 
         info.getFieldList().forEach(i ->
-            map.put(formatKey(i.getProperty()), new ColumnCache(i.getColumn(), i.getSqlSelect(), i.getMapping()))
+                map.put(formatKey(i.getProperty()), new ColumnCache(i.getColumn(), i.getSqlSelect(), i.getMapping()))
         );
         return map;
     }

+ 26 - 16
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/ReflectionKit.java

@@ -18,10 +18,8 @@ package com.baomidou.mybatisplus.core.toolkit;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
 
-import java.lang.reflect.Field;
-import java.lang.reflect.Modifier;
-import java.lang.reflect.ParameterizedType;
-import java.lang.reflect.Type;
+import java.lang.reflect.*;
+import java.security.AccessController;
 import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
@@ -100,12 +98,12 @@ public final class ReflectionKit {
         Type[] params = ((ParameterizedType) genType).getActualTypeArguments();
         if (index >= params.length || index < 0) {
             logger.warn(String.format("Warn: Index: %s, Size of %s's Parameterized Type: %s .", index,
-                clazz.getSimpleName(), params.length));
+                    clazz.getSimpleName(), params.length));
             return Object.class;
         }
         if (!(params[index] instanceof Class)) {
             logger.warn(String.format("Warn: %s not set the actual class on superclass generic parameter",
-                clazz.getSimpleName()));
+                    clazz.getSimpleName()));
             return Object.class;
         }
         return (Class<?>) params[index];
@@ -151,11 +149,11 @@ public final class ReflectionKit {
              * 中间表实体重写父类属性 ` private transient Date createTime; `
              */
             return fieldMap.values().stream()
-                /* 过滤静态属性 */
-                .filter(f -> !Modifier.isStatic(f.getModifiers()))
-                /* 过滤 transient关键字修饰的属性 */
-                .filter(f -> !Modifier.isTransient(f.getModifiers()))
-                .collect(Collectors.toList());
+                    /* 过滤静态属性 */
+                    .filter(f -> !Modifier.isStatic(f.getModifiers()))
+                    /* 过滤 transient关键字修饰的属性 */
+                    .filter(f -> !Modifier.isTransient(f.getModifiers()))
+                    .collect(Collectors.toList());
         });
     }
 
@@ -170,12 +168,12 @@ public final class ReflectionKit {
     public static Map<String, Field> excludeOverrideSuperField(Field[] fields, List<Field> superFieldList) {
         // 子类属性
         Map<String, Field> fieldMap = Stream.of(fields).collect(toMap(Field::getName, identity(),
-            (u, v) -> {
-                throw new IllegalStateException(String.format("Duplicate key %s", u));
-            },
-            LinkedHashMap::new));
+                (u, v) -> {
+                    throw new IllegalStateException(String.format("Duplicate key %s", u));
+                },
+                LinkedHashMap::new));
         superFieldList.stream().filter(field -> !fieldMap.containsKey(field.getName()))
-            .forEach(f -> fieldMap.put(f.getName(), f));
+                .forEach(f -> fieldMap.put(f.getName(), f));
         return fieldMap;
     }
 
@@ -193,4 +191,16 @@ public final class ReflectionKit {
     public static Class<?> resolvePrimitiveIfNecessary(Class<?> clazz) {
         return (clazz.isPrimitive() && clazz != void.class ? PRIMITIVE_TYPE_TO_WRAPPER_MAP.get(clazz) : clazz);
     }
+
+    /**
+     * 设置可访问对象的可访问权限为 true
+     *
+     * @param object 可访问的对象
+     * @param <T>    类型
+     * @return 返回设置后的对象
+     */
+    public static <T extends AccessibleObject> T setAccessible(T object) {
+        return AccessController.doPrivileged(new SetAccessibleAction<>(object));
+    }
+
 }

+ 22 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/SetAccessibleAction.java

@@ -0,0 +1,22 @@
+package com.baomidou.mybatisplus.core.toolkit;
+
+import java.lang.reflect.AccessibleObject;
+import java.security.PrivilegedAction;
+
+/**
+ * Create by hcl at 2021/5/14
+ */
+public class SetAccessibleAction<T extends AccessibleObject> implements PrivilegedAction<T> {
+    private final T obj;
+
+    public SetAccessibleAction(T obj) {
+        this.obj = obj;
+    }
+
+    @Override
+    public T run() {
+        obj.setAccessible(true);
+        return obj;
+    }
+
+}

+ 0 - 143
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/support/SerializedLambda.java

@@ -1,143 +0,0 @@
-/*
- * Copyright (c) 2011-2021, baomidou (jobob@qq.com).
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.baomidou.mybatisplus.core.toolkit.support;
-
-import com.baomidou.mybatisplus.core.toolkit.ClassUtils;
-import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
-import com.baomidou.mybatisplus.core.toolkit.SerializationUtils;
-
-import java.io.*;
-
-/**
- * 这个类是从 {@link java.lang.invoke.SerializedLambda} 里面 copy 过来的,
- * 字段信息完全一样
- * <p>负责将一个支持序列的 Function 序列化为 SerializedLambda</p>
- *
- * @author HCL
- * @since 2018/05/10
- */
-@SuppressWarnings("unused")
-public class SerializedLambda implements Serializable {
-
-    private static final long serialVersionUID = 8025925345765570181L;
-
-    private Class<?> capturingClass;
-    private String functionalInterfaceClass;
-    private String functionalInterfaceMethodName;
-    private String functionalInterfaceMethodSignature;
-    private String implClass;
-    private String implMethodName;
-    private String implMethodSignature;
-    private int implMethodKind;
-    private String instantiatedMethodType;
-    private Object[] capturedArgs;
-
-    /**
-     * 通过反序列化转换 lambda 表达式,该方法只能序列化 lambda 表达式,不能序列化接口实现或者正常非 lambda 写法的对象
-     *
-     * @param lambda lambda对象
-     * @return 返回解析后的 SerializedLambda
-     */
-    public static SerializedLambda resolve(SFunction<?, ?> lambda, ClassLoader classLoader) {
-        if (!lambda.getClass().isSynthetic()) {
-            throw ExceptionUtils.mpe("该方法仅能传入 lambda 表达式产生的合成类");
-        }
-        try (ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(SerializationUtils.serialize(lambda))) {
-            @Override
-            protected Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
-                Class<?> clazz;
-                try {
-                    clazz = ClassUtils.toClassConfident(objectStreamClass.getName(), classLoader);
-                } catch (Exception ex) {
-                    clazz = super.resolveClass(objectStreamClass);
-                }
-                return clazz == java.lang.invoke.SerializedLambda.class ? SerializedLambda.class : clazz;
-            }
-        }) {
-            return (SerializedLambda) objIn.readObject();
-        } catch (ClassNotFoundException | IOException e) {
-            throw ExceptionUtils.mpe("This is impossible to happen", e);
-        }
-    }
-
-    /**
-     * 获取接口 class
-     *
-     * @return 返回 class 名称
-     */
-    public String getFunctionalInterfaceClassName() {
-        return normalizedName(functionalInterfaceClass);
-    }
-
-    /**
-     * 获取实现的 class
-     *
-     * @return 实现类
-     */
-    public Class<?> getImplClass() {
-        return ClassUtils.toClassConfident(getImplClassName(), this.capturingClass.getClassLoader());
-    }
-
-    /**
-     * 获取 class 的名称
-     *
-     * @return 类名
-     */
-    public String getImplClassName() {
-        return normalizedName(implClass);
-    }
-
-    /**
-     * 获取实现者的方法名称
-     *
-     * @return 方法名称
-     */
-    public String getImplMethodName() {
-        return implMethodName;
-    }
-
-    /**
-     * 正常化类名称,将类名称中的 / 替换为 .
-     *
-     * @param name 名称
-     * @return 正常的类名
-     */
-    private String normalizedName(String name) {
-        return name.replace('/', '.');
-    }
-
-    /**
-     * @return 获取实例化方法的类型
-     */
-    public Class<?> getInstantiatedType() {
-        String instantiatedTypeName = normalizedName(instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(';')));
-        return ClassUtils.toClassConfident(instantiatedTypeName, this.capturingClass.getClassLoader());
-    }
-
-    /**
-     * @return 字符串形式
-     */
-    @Override
-    public String toString() {
-        String interfaceName = getFunctionalInterfaceClassName();
-        String implName = getImplClassName();
-        return String.format("%s -> %s::%s",
-            interfaceName.substring(interfaceName.lastIndexOf('.') + 1),
-            implName.substring(implName.lastIndexOf('.') + 1),
-            implMethodName);
-    }
-
-}

+ 11 - 34
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/test/toolkit/LambdaUtilsTest.java

@@ -16,13 +16,14 @@
 package com.baomidou.mybatisplus.test.toolkit;
 
 import com.baomidou.mybatisplus.core.toolkit.LambdaUtils;
-import com.baomidou.mybatisplus.core.toolkit.support.SerializedLambda;
+import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import lombok.Getter;
-import org.apache.ibatis.reflection.property.PropertyNamer;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import static org.junit.jupiter.api.Assertions.assertEquals;
+import java.lang.invoke.SerializedLambda;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertSame;
 
 /**
  * 测试 Lambda 解析类
@@ -33,35 +34,11 @@ class LambdaUtilsTest {
      * 测试解析
      */
     @Test
-    void testResolve() {
-        SerializedLambda lambda = LambdaUtils.resolve(TestModel::getId);
-        assertEquals(Parent.class.getName(), lambda.getImplClassName());
-        assertEquals("getId", lambda.getImplMethodName());
-        assertEquals("id", PropertyNamer.methodToProperty(lambda.getImplMethodName()));
-        assertEquals(TestModel.class, lambda.getInstantiatedType());
-
-        // 测试接口泛型获取
-        lambda = new TestModelHolder().toLambda();
-        // 无法从泛型获取到实现类,即使改泛型参数已经被实现
-        assertEquals(Named.class, lambda.getInstantiatedType());
-    }
-
-    /**
-     * 在 Java 中,一般来讲,只要是泛型,肯定是引用类型,但是为了避免翻车,还是测试一下
-     */
-    @Test
-    void test() {
-        assertInstantiatedMethodTypeIsReference(LambdaUtils.resolve(TestModel::getId));
-        assertInstantiatedMethodTypeIsReference(LambdaUtils.resolve(Integer::byteValue));
-    }
-
-    /**
-     * 断言当前方法所在实例的方法类型为引用类型
-     *
-     * @param lambda 解析后的 lambda
-     */
-    private void assertInstantiatedMethodTypeIsReference(SerializedLambda lambda) {
-        Assertions.assertNotNull(lambda.getInstantiatedType());
+    void testExtract() {
+        SFunction<TestModel, Object> getId = TestModel::getId;
+        SerializedLambda lambda = LambdaUtils.extract(getId);
+        assertNotNull(lambda);
+        assertSame(TestModel.class, LambdaUtils.instantiatedClass(lambda));
     }
 
     /**
@@ -81,7 +58,7 @@ class LambdaUtilsTest {
     private abstract static class BaseHolder<T extends Named> {
 
         SerializedLambda toLambda() {
-            return LambdaUtils.resolve(T::getName);
+            return LambdaUtils.extract(T::getName);
         }
 
     }

+ 1 - 1
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -498,7 +498,7 @@ class H2UserTest extends BaseTest {
     }
 
     /**
-     * 观察 {@link com.baomidou.mybatisplus.core.toolkit.LambdaUtils#resolve(SFunction)}
+     * 观察 {@link com.baomidou.mybatisplus.core.toolkit.LambdaUtils#extract(SFunction)}
      */
     @Test
     void testLambdaCache() {