Caratacus 9 年之前
父节点
当前提交
b957dd445c

+ 11 - 18
mybatis-plus/src/main/java/com/baomidou/framework/service/impl/ServiceImpl.java

@@ -15,14 +15,6 @@
  */
 package com.baomidou.framework.service.impl;
 
-import java.io.Serializable;
-import java.lang.reflect.Method;
-import java.util.List;
-import java.util.Map;
-
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.transaction.annotation.Transactional;
-
 import com.baomidou.framework.service.IService;
 import com.baomidou.mybatisplus.exceptions.MybatisPlusException;
 import com.baomidou.mybatisplus.mapper.BaseMapper;
@@ -31,6 +23,12 @@ import com.baomidou.mybatisplus.plugins.Page;
 import com.baomidou.mybatisplus.toolkit.ReflectionKit;
 import com.baomidou.mybatisplus.toolkit.TableInfo;
 import com.baomidou.mybatisplus.toolkit.TableInfoHelper;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.transaction.annotation.Transactional;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
 
 /**
  * <p>
@@ -73,16 +71,11 @@ public class ServiceImpl<M extends BaseMapper<T, PK>, T, PK extends Serializable
 			Class<?> cls = entity.getClass();
 			TableInfo tableInfo = TableInfoHelper.getTableInfo(cls);
 			if (null != tableInfo) {
-				try {
-					Method m = cls.getMethod(ReflectionKit.getMethodCapitalize(tableInfo.getKeyProperty()));
-					Object idVal = m.invoke(entity);
-					if (null != idVal) {
-						return isSelective ? updateSelectiveById(entity) : updateById(entity);
-					} else {
-						return isSelective ? insertSelective(entity) : insert(entity);
-					}
-				} catch (Exception e) {
-					e.printStackTrace();
+				Object idVal = ReflectionKit.getMethodValue(cls, entity, tableInfo.getKeyProperty());
+				if (null != idVal) {
+					return isSelective ? updateSelectiveById(entity) : updateById(entity);
+				} else {
+					return isSelective ? insertSelective(entity) : insert(entity);
 				}
 			} else {
 				throw new MybatisPlusException("Error:  Cannot execute. Could not find @TableId.");

+ 78 - 76
mybatis-plus/src/main/java/com/baomidou/mybatisplus/toolkit/ReflectionKit.java

@@ -1,12 +1,12 @@
 /**
  * Copyright (c) 2011-2020, hubin (jobob@qq.com).
- *
+ * <p>
  * Licensed under the Apache License, Version 2.0 (the "License"); you may not
  * use this file except in compliance with the License. You may obtain a copy of
  * the License at
- *
+ * <p>
  * http://www.apache.org/licenses/LICENSE-2.0
- *
+ * <p>
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
@@ -15,12 +15,13 @@
  */
 package com.baomidou.mybatisplus.toolkit;
 
-import java.lang.reflect.Method;
-import java.util.List;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.List;
+
 /**
  * <p>
  * 反射工具类
@@ -31,79 +32,80 @@ import org.slf4j.LoggerFactory;
  */
 public class ReflectionKit {
 
-	private static Logger logger = LoggerFactory.getLogger(ReflectionKit.class);
-
-	/**
-	 * <p>
-	 * 反射 method 方法名,例如 getId
-	 * </p>
-	 *
-	 * @param str
-	 *            属性字符串内容
-	 * @return
-	 */
-	public static String getMethodCapitalize(final String str) {
-		return StringUtils.concatCapitalize("get", str);
-	}
-
-	/**
-	 * 调用对象的get方法检查对象所有属性是否为null
-	 * 
-	 * @param bean
-	 *            检查对象
-	 * @return boolean true对象所有属性不为null,false对象所有属性为null
-	 */
-	public static boolean checkFieldValueNotNull(Object bean) {
-		if (null == bean) {
-			return false;
-		}
+    private static Logger logger = LoggerFactory.getLogger(ReflectionKit.class);
 
-		Class<?> cls = bean.getClass();
-		Method[] methods = cls.getDeclaredMethods();
-		TableInfo tableInfo = TableInfoHelper.getTableInfo(cls);
-		if (null == tableInfo) {
-			logger.warn("Warn: Could not find @TableId.");
-			return false;
-		}
+    /**
+     * <p>
+     * 反射 method 方法名,例如 getId
+     * </p>
+     *
+     * @param str 属性字符串内容
+     * @return
+     */
+    public static String getMethodCapitalize(final String str) {
+        return StringUtils.concatCapitalize("get", str);
+    }
 
-		boolean result = false;
-		List<TableFieldInfo> fieldList = tableInfo.getFieldList();
-		for (TableFieldInfo tableFieldInfo : fieldList) {
-			String fieldGetName = getMethodCapitalize(tableFieldInfo.getProperty());
-			if (!checkMethod(methods, fieldGetName)) {
-				continue;
-			}
-			try {
-				Method method = cls.getMethod(fieldGetName);
-				Object obj = method.invoke(bean);
-				if (null != obj) {
-					result = true;
-					break;
-				}
-			} catch (Exception e) {
-				logger.warn("Warn: Unexpected exception on checkFieldValueNull.  Cause:" + e);
-			}
+    /**
+     * 获取 public get方法的值
+     *
+     * @param cls
+     * @param entity 实体
+     * @param str    属性字符串内容
+     * @return Object
+     */
+    public static Object getMethodValue(Class cls, Object entity, String str) {
+        Object obj = null;
+        try {
+            Method method = cls.getMethod(getMethodCapitalize(str));
+            obj = method.invoke(entity);
+        } catch (NoSuchMethodException e) {
+            logger.warn("Warn: No such method. in " + cls);
+        } catch (IllegalAccessException e) {
+            logger.warn("Warn: Cannot execute a private method. in " + cls);
+        } catch (InvocationTargetException e) {
+            logger.warn("Warn: Unexpected exception on getMethodValue.  Cause:" + e);
+        }
+        return obj;
+    }
 
-		}
-		return result;
-	}
+    /**
+     * 获取 public get方法的值
+     *
+     * @param entity 实体
+     * @param str    属性字符串内容
+     * @return Object
+     */
+    public static Object getMethodValue(Object entity, String str) {
+        return getMethodValue(entity.getClass(), entity, str);
+    }
 
-	/**
-	 * 判断是否存在某属性的 get方法
-	 *
-	 * @param methods
-	 *            对象所有方法
-	 * @param method
-	 *            当前检查的方法
-	 * @return boolean true存在,false不存在
-	 */
-	public static boolean checkMethod(Method[] methods, String method) {
-		for (Method met : methods) {
-			if (method.equals(met.getName())) {
-				return true;
-			}
-		}
-		return false;
-	}
+    /**
+     * 调用对象的get方法检查对象所有属性是否为null
+     *
+     * @param bean 检查对象
+     * @return boolean true对象所有属性不为null,false对象所有属性为null
+     */
+    public static boolean checkFieldValueNotNull(Object bean) {
+        if (null == bean) {
+            return false;
+        }
+        Class<?> cls = bean.getClass();
+        TableInfo tableInfo = TableInfoHelper.getTableInfo(cls);
+        if (null == tableInfo) {
+            logger.warn("Warn: Could not find @TableId.");
+            return false;
+        }
+        boolean result = false;
+        List<TableFieldInfo> fieldList = tableInfo.getFieldList();
+        for (TableFieldInfo tableFieldInfo : fieldList) {
+            Object val = getMethodValue(cls, bean, tableFieldInfo.getProperty());
+            if (null != val) {
+                result = true;
+                break;
+            }
+        }
+        return result;
+    }
 
 }