Quellcode durchsuchen

开发多租户逻辑推一波

hubin vor 7 Jahren
Ursprung
Commit
57931347c9

+ 3 - 3
src/main/java/com/baomidou/mybatisplus/plugins/CachePaginationInterceptor.java

@@ -35,7 +35,7 @@ import org.apache.ibatis.session.RowBounds;
 import com.baomidou.mybatisplus.enums.DBType;
 import com.baomidou.mybatisplus.enums.DBType;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
-import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.ISqlParser;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.toolkit.JdbcUtils;
 import com.baomidou.mybatisplus.toolkit.JdbcUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
@@ -57,7 +57,7 @@ public class CachePaginationInterceptor extends PaginationInterceptor implements
     /* 溢出总页数,设置第一页 */
     /* 溢出总页数,设置第一页 */
     private boolean overflowCurrent = false;
     private boolean overflowCurrent = false;
     // COUNT SQL 解析
     // COUNT SQL 解析
-    private AbstractSqlParser sqlParser;
+    private ISqlParser sqlParser;
     /* 方言类型 */
     /* 方言类型 */
     private String dialectType;
     private String dialectType;
     /* 方言实现类 */
     /* 方言实现类 */
@@ -153,7 +153,7 @@ public class CachePaginationInterceptor extends PaginationInterceptor implements
         return this;
         return this;
     }
     }
 
 
-    public CachePaginationInterceptor setSqlParser(AbstractSqlParser sqlParser) {
+    public CachePaginationInterceptor setSqlParser(ISqlParser sqlParser) {
         this.sqlParser = sqlParser;
         this.sqlParser = sqlParser;
         return this;
         return this;
     }
     }

+ 10 - 10
src/main/java/com/baomidou/mybatisplus/plugins/PaginationInterceptor.java

@@ -41,7 +41,7 @@ import com.baomidou.mybatisplus.enums.DBType;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.PageHelper;
 import com.baomidou.mybatisplus.plugins.pagination.PageHelper;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
-import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.ISqlParser;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.toolkit.JdbcUtils;
 import com.baomidou.mybatisplus.toolkit.JdbcUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
@@ -62,7 +62,7 @@ public class PaginationInterceptor implements Interceptor {
     // 日志
     // 日志
     private static final Log logger = LogFactory.getLog(PaginationInterceptor.class);
     private static final Log logger = LogFactory.getLog(PaginationInterceptor.class);
     // COUNT SQL 解析
     // COUNT SQL 解析
-    private AbstractSqlParser sqlParser;
+    private ISqlParser sqlParser;
     /* 溢出总页数,设置第一页 */
     /* 溢出总页数,设置第一页 */
     private boolean overflowCurrent = false;
     private boolean overflowCurrent = false;
     /* 方言类型 */
     /* 方言类型 */
@@ -78,13 +78,13 @@ public class PaginationInterceptor implements Interceptor {
     @Override
     @Override
     public Object intercept(Invocation invocation) throws Throwable {
     public Object intercept(Invocation invocation) throws Throwable {
         StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
         StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
-        MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
+        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
         // 先判断是不是SELECT操作
         // 先判断是不是SELECT操作
-        MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
+        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
         if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
         if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
             return invocation.proceed();
             return invocation.proceed();
         }
         }
-        RowBounds rowBounds = (RowBounds) metaStatementHandler.getValue("delegate.rowBounds");
+        RowBounds rowBounds = (RowBounds) metaObject.getValue("delegate.rowBounds");
         /* 不需要分页的场合 */
         /* 不需要分页的场合 */
         if (rowBounds == null || rowBounds == RowBounds.DEFAULT) {
         if (rowBounds == null || rowBounds == RowBounds.DEFAULT) {
             // 本地线程分页
             // 本地线程分页
@@ -100,7 +100,7 @@ public class PaginationInterceptor implements Interceptor {
             }
             }
         }
         }
         // 针对定义了rowBounds,做为mapper接口方法的参数
         // 针对定义了rowBounds,做为mapper接口方法的参数
-        BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
+        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
         String originalSql = boundSql.getSql();
         String originalSql = boundSql.getSql();
         Connection connection = (Connection) invocation.getArgs()[0];
         Connection connection = (Connection) invocation.getArgs()[0];
         DBType dbType = StringUtils.isNotEmpty(dialectType) ? DBType.getDBType(dialectType) : JdbcUtils.getDbType(connection.getMetaData().getURL());
         DBType dbType = StringUtils.isNotEmpty(dialectType) ? DBType.getDBType(dialectType) : JdbcUtils.getDbType(connection.getMetaData().getURL());
@@ -126,9 +126,9 @@ public class PaginationInterceptor implements Interceptor {
          * <p> 禁用内存分页 </p>
          * <p> 禁用内存分页 </p>
          * <p> 内存分页会查询所有结果出来处理(这个很吓人的),如果结果变化频繁这个数据还会不准。</p>
          * <p> 内存分页会查询所有结果出来处理(这个很吓人的),如果结果变化频繁这个数据还会不准。</p>
 		 */
 		 */
-        metaStatementHandler.setValue("delegate.boundSql.sql", originalSql);
-        metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
-        metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
+        metaObject.setValue("delegate.boundSql.sql", originalSql);
+        metaObject.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
+        metaObject.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
         return invocation.proceed();
         return invocation.proceed();
     }
     }
 
 
@@ -200,7 +200,7 @@ public class PaginationInterceptor implements Interceptor {
         return this;
         return this;
     }
     }
 
 
-    public PaginationInterceptor setSqlParser(AbstractSqlParser sqlParser) {
+    public PaginationInterceptor setSqlParser(ISqlParser sqlParser) {
         this.sqlParser = sqlParser;
         this.sqlParser = sqlParser;
         return this;
         return this;
     }
     }

+ 5 - 5
src/main/java/com/baomidou/mybatisplus/plugins/SqlParserInterceptor.java

@@ -28,7 +28,7 @@ import org.apache.ibatis.plugin.Signature;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.SystemMetaObject;
 import org.apache.ibatis.reflection.SystemMetaObject;
 
 
-import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.ISqlParser;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
@@ -46,7 +46,7 @@ public class SqlParserInterceptor implements Interceptor {
 
 
     private static final String DELEGATE_BOUNDSQL_SQL = "delegate.boundSql.sql";
     private static final String DELEGATE_BOUNDSQL_SQL = "delegate.boundSql.sql";
     // SQL 解析
     // SQL 解析
-    private List<AbstractSqlParser> sqlParserList;
+    private List<ISqlParser> sqlParserList;
 
 
     /**
     /**
      * 拦截 SQL 解析执行
      * 拦截 SQL 解析执行
@@ -59,7 +59,7 @@ public class SqlParserInterceptor implements Interceptor {
         if (CollectionUtils.isNotEmpty(sqlParserList)) {
         if (CollectionUtils.isNotEmpty(sqlParserList)) {
             int flag = 0;// 标记是否修改过 SQL
             int flag = 0;// 标记是否修改过 SQL
             String originalSql = (String) metaObject.getValue(DELEGATE_BOUNDSQL_SQL);
             String originalSql = (String) metaObject.getValue(DELEGATE_BOUNDSQL_SQL);
-            for (AbstractSqlParser sqlParser : sqlParserList) {
+            for (ISqlParser sqlParser : sqlParserList) {
                 SqlInfo sqlInfo = sqlParser.optimizeSql(metaObject, originalSql);
                 SqlInfo sqlInfo = sqlParser.optimizeSql(metaObject, originalSql);
                 if (null != sqlInfo) {
                 if (null != sqlInfo) {
                     originalSql = sqlInfo.getSql();
                     originalSql = sqlInfo.getSql();
@@ -86,11 +86,11 @@ public class SqlParserInterceptor implements Interceptor {
         // to do nothing
         // to do nothing
     }
     }
 
 
-    public List<AbstractSqlParser> getSqlParserList() {
+    public List<ISqlParser> getSqlParserList() {
         return sqlParserList;
         return sqlParserList;
     }
     }
 
 
-    public SqlParserInterceptor setSqlParserList(List<AbstractSqlParser> sqlParserList) {
+    public SqlParserInterceptor setSqlParserList(List<ISqlParser> sqlParserList) {
         this.sqlParserList = sqlParserList;
         this.sqlParserList = sqlParserList;
         return this;
         return this;
     }
     }

+ 7 - 3
src/main/java/com/baomidou/mybatisplus/plugins/pagination/optimize/JsqlParserCountOptimize.java

@@ -18,9 +18,11 @@ package com.baomidou.mybatisplus.plugins.pagination.optimize;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.List;
 
 
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.MetaObject;
 
 
-import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.ISqlParser;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.toolkit.SqlUtils;
 import com.baomidou.mybatisplus.toolkit.SqlUtils;
@@ -45,8 +47,10 @@ import net.sf.jsqlparser.statement.select.SelectItem;
  * @author hubin
  * @author hubin
  * @since 2017-06-20
  * @since 2017-06-20
  */
  */
-public class JsqlParserCountOptimize extends AbstractSqlParser {
+public class JsqlParserCountOptimize implements ISqlParser {
 
 
+    // 日志
+    private final Log logger = LogFactory.getLog(JsqlParserCountOptimize.class);
     private static final List<SelectItem> countSelectItem = countSelectItem();
     private static final List<SelectItem> countSelectItem = countSelectItem();
 
 
     @Override
     @Override
@@ -69,7 +73,7 @@ public class JsqlParserCountOptimize extends AbstractSqlParser {
             }
             }
             //#95 Github, selectItems contains #{} ${}, which will be translated to ?, and it may be in a function: power(#{myInt},2)
             //#95 Github, selectItems contains #{} ${}, which will be translated to ?, and it may be in a function: power(#{myInt},2)
             for (SelectItem item : plainSelect.getSelectItems()) {
             for (SelectItem item : plainSelect.getSelectItems()) {
-                if(item.toString().contains("?")){
+                if (item.toString().contains("?")) {
                     sqlInfo.setSql(String.format(SqlUtils.SQL_BASE_COUNT, selectStatement.toString()));
                     sqlInfo.setSql(String.format(SqlUtils.SQL_BASE_COUNT, selectStatement.toString()));
                     return sqlInfo;
                     return sqlInfo;
                 }
                 }

+ 87 - 0
src/main/java/com/baomidou/mybatisplus/plugins/parser/AbstractJsqlParser.java

@@ -0,0 +1,87 @@
+/**
+ * Copyright (c) 2011-2020, hubin (jobob@qq.com).
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.baomidou.mybatisplus.plugins.parser;
+
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
+import org.apache.ibatis.reflection.MetaObject;
+
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.Statement;
+
+/**
+ * <p>
+ * 抽象 SQL 解析类
+ * </p>
+ *
+ * @author hubin
+ * @Date 2017-06-20
+ */
+public abstract class AbstractJsqlParser implements ISqlParser {
+
+    // 日志
+    protected final Log logger = LogFactory.getLog(this.getClass());
+
+    /**
+     * <p>
+     * 获取优化 SQL 方法
+     * </p>
+     *
+     * @param metaObject 元对象
+     * @param sql        SQL 语句
+     * @return SQL 信息
+     */
+
+    @Override
+    public SqlInfo optimizeSql(MetaObject metaObject, String sql) {
+        if (this.allowProcess(metaObject)) {
+            try {
+                Statement statement = CCJSqlParserUtil.parse(sql);
+                logger.debug("old sql: " + sql + ",statement: " + statement);
+                if (null != statement) {
+                    return this.processParser(statement);
+                }
+            } catch (JSQLParserException e) {
+                logger.error("解析sql: " + sql + ",异常: " + e.getMessage());
+            }
+        }
+        return null;
+    }
+
+    /**
+     * <p>
+     * 执行 SQL 解析
+     * </p>
+     *
+     * @param statement JsqlParser Statement
+     * @return
+     */
+    public abstract SqlInfo processParser(Statement statement);
+
+    /**
+     * <p>
+     * 判断是否允许执行<br>
+     * 例如:逻辑删除只解析 delete , update 操作
+     * </p>
+     *
+     * @param metaObject 元对象
+     * @return true
+     */
+    public boolean allowProcess(MetaObject metaObject) {
+        return true;
+    }
+}

+ 4 - 9
src/main/java/com/baomidou/mybatisplus/plugins/parser/AbstractSqlParser.java → src/main/java/com/baomidou/mybatisplus/plugins/parser/ISqlParser.java

@@ -15,22 +15,17 @@
  */
  */
 package com.baomidou.mybatisplus.plugins.parser;
 package com.baomidou.mybatisplus.plugins.parser;
 
 
-import org.apache.ibatis.logging.Log;
-import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.MetaObject;
 
 
 /**
 /**
  * <p>
  * <p>
- * 抽象 SQL 解析类
+ * SQL 解析接口
  * </p>
  * </p>
  *
  *
  * @author hubin
  * @author hubin
- * @Date 2017-06-20
+ * @Date 2017-09-01
  */
  */
-public abstract class AbstractSqlParser {
-
-    // 日志
-    protected final Log logger = LogFactory.getLog(this.getClass());
+public interface ISqlParser {
 
 
     /**
     /**
      * <p>
      * <p>
@@ -41,6 +36,6 @@ public abstract class AbstractSqlParser {
      * @param sql        SQL 语句
      * @param sql        SQL 语句
      * @return SQL 信息
      * @return SQL 信息
      */
      */
-    public abstract SqlInfo optimizeSql(MetaObject metaObject, String sql);
+    SqlInfo optimizeSql(MetaObject metaObject, String sql);
 
 
 }
 }

+ 24 - 36
src/main/java/com/baomidou/mybatisplus/plugins/tenancy/TenancySqlParser.java

@@ -17,19 +17,15 @@ package com.baomidou.mybatisplus.plugins.tenancy;
 
 
 import java.util.List;
 import java.util.List;
 
 
-import org.apache.ibatis.reflection.MetaObject;
-
-import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.AbstractJsqlParser;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 
 
-import net.sf.jsqlparser.JSQLParserException;
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.StringValue;
 import net.sf.jsqlparser.expression.StringValue;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
 import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
 import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
-import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.Statement;
 import net.sf.jsqlparser.statement.Statement;
@@ -55,33 +51,24 @@ import net.sf.jsqlparser.statement.update.Update;
  * </p>
  * </p>
  *
  *
  * @author hubin
  * @author hubin
- * @since 2017-06-20
+ * @since 2017-09-01
  */
  */
-public class TenancySqlParser extends AbstractSqlParser {
+public class TenancySqlParser extends AbstractJsqlParser {
 
 
     private TenantHandler tenantHandler;
     private TenantHandler tenantHandler;
 
 
     @Override
     @Override
-    public SqlInfo optimizeSql(MetaObject metaObject, String sql) {
-        //logger.debug("old sql:{}", sql);
-        Statement stmt = null;
-        try {
-            stmt = CCJSqlParserUtil.parse(sql);
-        } catch (JSQLParserException e) {
-            //logger.debug("解析", e);
-            //logger.error("解析sql[{}]失败\n原因:{}", sql, e.getMessage());
-            //如果解析失败不进行任何处理防止业务中断
-            return null;
-        }
-        if (stmt instanceof Insert) {
-            processInsert((Insert) stmt);
-        } else if (stmt instanceof Select) {
-            processSelectBody(((Select) stmt).getSelectBody());
-        } else if (stmt instanceof Update) {
-            processUpdate((Update) stmt);
+    public SqlInfo processParser(Statement statement) {
+        if (statement instanceof Insert) {
+            this.processInsert((Insert) statement);
+        } else if (statement instanceof Select) {
+            this.processSelectBody(((Select) statement).getSelectBody());
+        } else if (statement instanceof Update) {
+            this.processUpdate((Update) statement);
+        } else if (statement instanceof Delete) {
+            this.processDelete((Delete) statement);
         }
         }
-        //logger.debug("new sql:{}", stmt);
-        return SqlInfo.newInstance().setSql(stmt.toString());
+        return SqlInfo.newInstance().setSql(statement.toString());
     }
     }
 
 
     /**
     /**
@@ -121,7 +108,7 @@ public class TenancySqlParser extends AbstractSqlParser {
             if (insert.getSelect() != null) {
             if (insert.getSelect() != null) {
                 processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
                 processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
             } else if (insert.getItemsList() != null) {
             } else if (insert.getItemsList() != null) {
-                ((ExpressionList) insert.getItemsList()).getExpressions().add(new StringValue("," + this.tenantHandler.getTenantId() + ","));
+                ((ExpressionList) insert.getItemsList()).getExpressions().add(new StringValue(this.tenantHandler.getTenantId()));
             } else {
             } else {
                 throw new RuntimeException("无法处理的 sql");
                 throw new RuntimeException("无法处理的 sql");
             }
             }
@@ -139,12 +126,12 @@ public class TenancySqlParser extends AbstractSqlParser {
         EqualsTo equalsTo = new EqualsTo();
         EqualsTo equalsTo = new EqualsTo();
         if (where instanceof BinaryExpression) {
         if (where instanceof BinaryExpression) {
             equalsTo.setLeftExpression(new Column(this.tenantHandler.getTenantIdColumn()));
             equalsTo.setLeftExpression(new Column(this.tenantHandler.getTenantIdColumn()));
-            equalsTo.setRightExpression(new StringValue("," + tenantHandler.getTenantId() + ","));
+            equalsTo.setRightExpression(new StringValue(tenantHandler.getTenantId()));
             AndExpression andExpression = new AndExpression(equalsTo, where);
             AndExpression andExpression = new AndExpression(equalsTo, where);
             update.setWhere(andExpression);
             update.setWhere(andExpression);
         } else {
         } else {
             equalsTo.setLeftExpression(new Column(this.tenantHandler.getTenantIdColumn()));
             equalsTo.setLeftExpression(new Column(this.tenantHandler.getTenantIdColumn()));
-            equalsTo.setRightExpression(new StringValue("," + tenantHandler.getTenantId() + ","));
+            equalsTo.setRightExpression(new StringValue(tenantHandler.getTenantId()));
             update.setWhere(equalsTo);
             update.setWhere(equalsTo);
         }
         }
     }
     }
@@ -160,14 +147,16 @@ public class TenancySqlParser extends AbstractSqlParser {
     }
     }
 
 
     /**
     /**
-     * 处理PlainSelect
+     * 处理 PlainSelect
      */
      */
     public void processPlainSelect(PlainSelect plainSelect) {
     public void processPlainSelect(PlainSelect plainSelect) {
         processPlainSelect(plainSelect, false);
         processPlainSelect(plainSelect, false);
     }
     }
 
 
     /**
     /**
-     * 处理PlainSelect
+     * <p>
+     * 处理 PlainSelect
+     * </p>
      *
      *
      * @param plainSelect
      * @param plainSelect
      * @param addColumn   是否添加租户列,insert into select语句中需要
      * @param addColumn   是否添加租户列,insert into select语句中需要
@@ -179,8 +168,9 @@ public class TenancySqlParser extends AbstractSqlParser {
             Table fromTable = (Table) fromItem;
             Table fromTable = (Table) fromItem;
             if (doTableFilter(fromTable.getName())) {
             if (doTableFilter(fromTable.getName())) {
                 plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
                 plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
-                if (addColumn)
-                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column("'" + this.tenantHandler.getTenantId() + "'")));
+                if (addColumn) {
+                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(this.tenantHandler.getTenantId())));
+                }
             }
             }
         } else {
         } else {
             processFromItem(fromItem);
             processFromItem(fromItem);
@@ -237,7 +227,6 @@ public class TenancySqlParser extends AbstractSqlParser {
             if (doTableFilter(fromTable.getName())) {
             if (doTableFilter(fromTable.getName())) {
                 join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
                 join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
             }
             }
-
         }
         }
     }
     }
 
 
@@ -266,7 +255,7 @@ public class TenancySqlParser extends AbstractSqlParser {
         EqualsTo equalsTo = new EqualsTo();
         EqualsTo equalsTo = new EqualsTo();
         tenantExpression = equalsTo;
         tenantExpression = equalsTo;
         equalsTo.setLeftExpression(tenantColumn);
         equalsTo.setLeftExpression(tenantColumn);
-        equalsTo.setRightExpression(new StringValue("'" + this.tenantHandler.getTenantId() + "'"));
+        equalsTo.setRightExpression(new StringValue(this.tenantHandler.getTenantId()));
 
 
         //加入判断防止条件为空时生成 "and null" 导致查询结果为空
         //加入判断防止条件为空时生成 "and null" 导致查询结果为空
         if (expression == null) {
         if (expression == null) {
@@ -283,7 +272,6 @@ public class TenancySqlParser extends AbstractSqlParser {
             }
             }
             return new AndExpression(tenantExpression, expression);
             return new AndExpression(tenantExpression, expression);
         }
         }
-
     }
     }
 
 
     private boolean doTableFilter(String table) {
     private boolean doTableFilter(String table) {

+ 1 - 3
src/main/java/com/baomidou/mybatisplus/plugins/tenancy/TenantHandler.java

@@ -15,8 +15,6 @@
  */
  */
 package com.baomidou.mybatisplus.plugins.tenancy;
 package com.baomidou.mybatisplus.plugins.tenancy;
 
 
-import java.io.Serializable;
-
 /**
 /**
  * <p>
  * <p>
  * 租户处理器
  * 租户处理器
@@ -27,7 +25,7 @@ import java.io.Serializable;
  */
  */
 public interface TenantHandler {
 public interface TenantHandler {
 
 
-    Serializable getTenantId();
+    String getTenantId();
 
 
     String getTenantIdColumn();
     String getTenantIdColumn();
 
 

+ 3 - 3
src/main/java/com/baomidou/mybatisplus/toolkit/SqlUtils.java

@@ -16,7 +16,7 @@
 package com.baomidou.mybatisplus.toolkit;
 package com.baomidou.mybatisplus.toolkit;
 
 
 import com.baomidou.mybatisplus.enums.SqlLike;
 import com.baomidou.mybatisplus.enums.SqlLike;
-import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.ISqlParser;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
 import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptimize;
 import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptimize;
@@ -33,7 +33,7 @@ public class SqlUtils {
 
 
     private final static SqlFormatter sqlFormatter = new SqlFormatter();
     private final static SqlFormatter sqlFormatter = new SqlFormatter();
     public final static String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s ) TOTAL";
     public final static String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s ) TOTAL";
-    public static AbstractSqlParser COUNT_SQL_PARSER = null;
+    public static ISqlParser COUNT_SQL_PARSER = null;
 
 
 
 
     /**
     /**
@@ -45,7 +45,7 @@ public class SqlUtils {
      * @param originalSql 需要计算Count SQL
      * @param originalSql 需要计算Count SQL
      * @return SqlInfo
      * @return SqlInfo
      */
      */
-    public static SqlInfo getCountOptimize(AbstractSqlParser sqlParser, String originalSql) {
+    public static SqlInfo getCountOptimize(ISqlParser sqlParser, String originalSql) {
         // COUNT SQL 解析器
         // COUNT SQL 解析器
         if (null == COUNT_SQL_PARSER) {
         if (null == COUNT_SQL_PARSER) {
             if (null != sqlParser) {
             if (null != sqlParser) {

+ 1 - 3
src/test/java/com/baomidou/mybatisplus/test/SqlBuilderTest.java → src/test/java/com/baomidou/mybatisplus/test/sql/SqlBuilderTest.java

@@ -1,5 +1,4 @@
-package com.baomidou.mybatisplus.test;
-
+package com.baomidou.mybatisplus.test.sql;
 
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertEquals;
 
 
@@ -7,7 +6,6 @@ import org.apache.ibatis.jdbc.SQL;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 
 
-
 /**
 /**
  * http://www.mybatis.org/mybatis-3/zh/statement-builders.html
  * http://www.mybatis.org/mybatis-3/zh/statement-builders.html
  * <p>
  * <p>

+ 4 - 4
src/test/java/com/baomidou/mybatisplus/test/SqlUtilsTest.java → src/test/java/com/baomidou/mybatisplus/test/sql/SqlUtilsTest.java

@@ -1,10 +1,10 @@
-package com.baomidou.mybatisplus.test;
+package com.baomidou.mybatisplus.test.sql;
 
 
 import org.junit.Assert;
 import org.junit.Assert;
 import org.junit.Test;
 import org.junit.Test;
 
 
-import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptimize;
 import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptimize;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 
 
 /**
 /**
  * <p>
  * <p>
@@ -12,12 +12,12 @@ import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptim
  * </p>
  * </p>
  *
  *
  * @author Caratacus
  * @author Caratacus
- * @Date 2016-11-3
+ * @since 2016-11-3
  */
  */
 public class SqlUtilsTest {
 public class SqlUtilsTest {
 
 
     public SqlInfo jsqlParserCountSqlInfo(String sql) {
     public SqlInfo jsqlParserCountSqlInfo(String sql) {
-        return new JsqlParserCountOptimize().optimizeSql(sql);
+        return new JsqlParserCountOptimize().optimizeSql(null, sql);
     }
     }
 
 
     /**
     /**

+ 45 - 0
src/test/java/com/baomidou/mybatisplus/test/sql/TenancySqlTest.java

@@ -0,0 +1,45 @@
+package com.baomidou.mybatisplus.test.sql;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
+import com.baomidou.mybatisplus.plugins.tenancy.TenancySqlParser;
+import com.baomidou.mybatisplus.plugins.tenancy.TenantHandler;
+
+/**
+ * <p>
+ * 租户 SQL 测试
+ * </p>
+ *
+ * @author hubin
+ * @since 2017-09-01
+ */
+public class TenancySqlTest {
+
+    private TenancySqlParser tenancySqlParser;
+
+    @Before
+    public void setUp() throws Exception {
+        tenancySqlParser = new TenancySqlParser();
+        tenancySqlParser.setTenantHandler(new TenantHandler() {
+            @Override
+            public String getTenantId() {
+                return "1000";
+            }
+
+            @Override
+            public String getTenantIdColumn() {
+                return "tenant_id";
+            }
+        });
+    }
+
+    @Test
+    public void test1() {
+        SqlInfo sqlInfo = tenancySqlParser.optimizeSql(null, "select * from user");
+        if (null != sqlInfo) {
+            System.err.println(sqlInfo.getSql());
+        }
+    }
+}