TenancySqlParser.java 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. /**
  2. * Copyright (c) 2011-2020, hubin (jobob@qq.com).
  3. * <p>
  4. * Licensed under the Apache License, Version 2.0 (the "License"); you may not
  5. * use this file except in compliance with the License. You may obtain a copy of
  6. * the License at
  7. * <p>
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. * <p>
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  12. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  13. * License for the specific language governing permissions and limitations under
  14. * the License.
  15. */
  16. package com.baomidou.mybatisplus.plugins.tenancy;
  import java.util.ArrayList;
  import java.util.List;

  import com.baomidou.mybatisplus.parser.AbstractSqlParser;
  import com.baomidou.mybatisplus.parser.SqlInfo;

  import net.sf.jsqlparser.JSQLParserException;
  import net.sf.jsqlparser.expression.BinaryExpression;
  import net.sf.jsqlparser.expression.Expression;
  import net.sf.jsqlparser.expression.Parenthesis;
  import net.sf.jsqlparser.expression.StringValue;
  import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
  import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
  import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
  import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
  import net.sf.jsqlparser.expression.operators.relational.InExpression;
  import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
  import net.sf.jsqlparser.parser.CCJSqlParserUtil;
  import net.sf.jsqlparser.schema.Column;
  import net.sf.jsqlparser.schema.Table;
  import net.sf.jsqlparser.statement.Statement;
  import net.sf.jsqlparser.statement.delete.Delete;
  import net.sf.jsqlparser.statement.insert.Insert;
  import net.sf.jsqlparser.statement.select.FromItem;
  import net.sf.jsqlparser.statement.select.Join;
  import net.sf.jsqlparser.statement.select.LateralSubSelect;
  import net.sf.jsqlparser.statement.select.PlainSelect;
  import net.sf.jsqlparser.statement.select.Select;
  import net.sf.jsqlparser.statement.select.SelectBody;
  import net.sf.jsqlparser.statement.select.SelectExpressionItem;
  import net.sf.jsqlparser.statement.select.SetOperationList;
  import net.sf.jsqlparser.statement.select.SubJoin;
  import net.sf.jsqlparser.statement.select.SubSelect;
  import net.sf.jsqlparser.statement.select.ValuesList;
  import net.sf.jsqlparser.statement.select.WithItem;
  import net.sf.jsqlparser.statement.update.Update;
  48. /**
  49. * <p>
  50. * 租户 SQL 解析
  51. * </p>
  52. *
  53. * @author hubin
  54. * @since 2017-06-20
  55. */
  56. public class TenancySqlParser extends AbstractSqlParser {
  57. private TenantInfo tenantInfo;
  58. @Override
  59. public SqlInfo optimizeSql(String sql) {
  60. //logger.debug("old sql:{}", sql);
  61. Statement stmt = null;
  62. try {
  63. stmt = CCJSqlParserUtil.parse(sql);
  64. } catch (JSQLParserException e) {
  65. //logger.debug("解析", e);
  66. //logger.error("解析sql[{}]失败\n原因:{}", sql, e.getMessage());
  67. //如果解析失败不进行任何处理防止业务中断
  68. return null;
  69. }
  70. if (stmt instanceof Insert) {
  71. processInsert((Insert) stmt);
  72. } else if (stmt instanceof Select) {
  73. processSelectBody(((Select) stmt).getSelectBody());
  74. } else if (stmt instanceof Update) {
  75. processUpdate((Update) stmt);
  76. }
  77. //logger.debug("new sql:{}", stmt);
  78. SqlInfo sqlInfo = SqlInfo.newInstance();
  79. sqlInfo.setSql(stmt.toString());
  80. return sqlInfo;
  81. }
  82. /**
  83. * <p>
  84. * select 语句处理
  85. * </p>
  86. */
  87. public void processSelectBody(SelectBody selectBody) {
  88. if (selectBody instanceof PlainSelect) {
  89. processPlainSelect((PlainSelect) selectBody);
  90. } else if (selectBody instanceof WithItem) {
  91. WithItem withItem = (WithItem) selectBody;
  92. if (withItem.getSelectBody() != null) {
  93. processSelectBody(withItem.getSelectBody());
  94. }
  95. } else {
  96. SetOperationList operationList = (SetOperationList) selectBody;
  97. if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
  98. List<SelectBody> plainSelects = operationList.getSelects();
  99. for (SelectBody plainSelect : plainSelects) {
  100. processSelectBody(plainSelect);
  101. }
  102. }
  103. }
  104. }
  105. /**
  106. * <p>
  107. * insert 语句处理
  108. * </p>
  109. */
  110. public void processInsert(Insert insert) {
  111. if (doTableFilter(
  112. insert.getTable().getName()
  113. )) {
  114. insert.getColumns().add(new Column(this.tenantInfo.getTenantIdColumn()));
  115. if (insert.getSelect() != null) {
  116. processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
  117. } else if (insert.getItemsList() != null) {
  118. ((ExpressionList) insert.getItemsList()).getExpressions().add(new StringValue("," + this.tenantInfo.getTenantId() + ","));
  119. } else {
  120. //
  121. throw new RuntimeException("无法处理的 sql");
  122. }
  123. }
  124. }
  125. /**
  126. * <p>
  127. * update 语句处理
  128. * </p>
  129. */
  130. public void processUpdate(Update update) {
  131. //获得where条件表达式
  132. Expression where = update.getWhere();
  133. EqualsTo equalsTo = new EqualsTo();
  134. if (where instanceof BinaryExpression) {
  135. equalsTo.setLeftExpression(new Column(this.tenantInfo.getTenantIdColumn()));
  136. equalsTo.setRightExpression(new StringValue("," + tenantInfo.getTenantId() + ","));
  137. AndExpression andExpression = new AndExpression(equalsTo, where);
  138. update.setWhere(andExpression);
  139. } else {
  140. equalsTo.setLeftExpression(new Column(this.tenantInfo.getTenantIdColumn()));
  141. equalsTo.setRightExpression(new StringValue("," + tenantInfo.getTenantId() + ","));
  142. update.setWhere(equalsTo);
  143. }
  144. }
  145. /**
  146. * <p>
  147. * delete 语句处理
  148. * </p>
  149. */
  150. public void processDelete(Delete delete) {
  151. }
  152. /**
  153. * 处理PlainSelect
  154. */
  155. public void processPlainSelect(PlainSelect plainSelect) {
  156. processPlainSelect(plainSelect, false);
  157. }
  158. /**
  159. * 处理PlainSelect
  160. *
  161. * @param plainSelect
  162. * @param addColumn 是否添加租户列,insert into select语句中需要
  163. */
  164. public void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
  165. FromItem fromItem = plainSelect.getFromItem();
  166. if (fromItem instanceof Table) {
  167. Table fromTable = (Table) fromItem;
  168. if (doTableFilter(fromTable.getName())) {
  169. plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
  170. if (addColumn)
  171. plainSelect.getSelectItems().add(new SelectExpressionItem(new Column("'" + this.tenantInfo.getTenantId() + "'")));
  172. }
  173. } else {
  174. processFromItem(fromItem);
  175. }
  176. List<Join> joins = plainSelect.getJoins();
  177. if (joins != null && joins.size() > 0) {
  178. for (Join join : joins) {
  179. processJoin(join);
  180. processFromItem(join.getRightItem());
  181. }
  182. }
  183. }
  184. /**
  185. * 处理子查询等
  186. *
  187. * @param fromItem
  188. */
  189. public void processFromItem(FromItem fromItem) {
  190. if (fromItem instanceof SubJoin) {
  191. SubJoin subJoin = (SubJoin) fromItem;
  192. if (subJoin.getJoin() != null) {
  193. processJoin(subJoin.getJoin());
  194. }
  195. if (subJoin.getLeft() != null) {
  196. processFromItem(subJoin.getLeft());
  197. }
  198. } else if (fromItem instanceof SubSelect) {
  199. SubSelect subSelect = (SubSelect) fromItem;
  200. if (subSelect.getSelectBody() != null) {
  201. processSelectBody(subSelect.getSelectBody());
  202. }
  203. } else if (fromItem instanceof ValuesList) {
  204. } else if (fromItem instanceof LateralSubSelect) {
  205. LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
  206. if (lateralSubSelect.getSubSelect() != null) {
  207. SubSelect subSelect = lateralSubSelect.getSubSelect();
  208. if (subSelect.getSelectBody() != null) {
  209. processSelectBody(subSelect.getSelectBody());
  210. }
  211. }
  212. }
  213. }
  214. /**
  215. * 处理联接语句
  216. *
  217. * @param join
  218. */
  219. public void processJoin(Join join) {
  220. if (join.getRightItem() instanceof Table) {
  221. Table fromTable = (Table) join.getRightItem();
  222. if (doTableFilter(fromTable.getName())) {
  223. join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
  224. }
  225. }
  226. }
  227. /**
  228. * 处理条件
  229. * TODO 未解决sql注入问题(考虑替换StringValue为LongValue),因为线上数据库租户字段为int暂时不存在注入问题
  230. *
  231. * @param expression
  232. * @param table
  233. * @return
  234. */
  235. public Expression builderExpression(Expression expression, Table table) {
  236. Expression tenantExpression = null;
  237. String[] tenantIds = this.tenantInfo.getTenantId().split(",");
  238. //当传入table时,字段前加上别名或者table名
  239. //别名优先使用
  240. StringBuilder tenantIdColumnName = new StringBuilder();
  241. if (table != null) {
  242. tenantIdColumnName.append(table.getAlias() != null ? table.getAlias().getName() : table.getName());
  243. tenantIdColumnName.append(".");
  244. }
  245. tenantIdColumnName.append(this.tenantInfo.getTenantIdColumn());
  246. //生成字段名
  247. Column tenantColumn = new Column(tenantIdColumnName.toString());
  248. if (tenantIds.length == 1) {
  249. EqualsTo equalsTo = new EqualsTo();
  250. tenantExpression = equalsTo;
  251. equalsTo.setLeftExpression(tenantColumn);
  252. equalsTo.setRightExpression(new StringValue("'" + tenantIds[0] + "'"));
  253. } else {
  254. //多租户身份
  255. InExpression inExpression = new InExpression();
  256. tenantExpression = inExpression;
  257. inExpression.setLeftExpression(tenantColumn);
  258. List<Expression> valueList = new ArrayList<>();
  259. for (String tid : tenantIds) {
  260. valueList.add(new StringValue("'" + tid + "'"));
  261. }
  262. inExpression.setRightItemsList(new ExpressionList(valueList));
  263. }
  264. //加入判断防止条件为空时生成 "and null" 导致查询结果为空
  265. if (expression == null) {
  266. return tenantExpression;
  267. } else {
  268. if (expression instanceof BinaryExpression) {
  269. BinaryExpression binaryExpression = (BinaryExpression) expression;
  270. if (binaryExpression.getLeftExpression() instanceof FromItem) {
  271. processFromItem((FromItem) binaryExpression.getLeftExpression());
  272. }
  273. if (binaryExpression.getRightExpression() instanceof FromItem) {
  274. processFromItem((FromItem) binaryExpression.getRightExpression());
  275. }
  276. }
  277. return new AndExpression(tenantExpression, expression);
  278. }
  279. }
  280. private boolean doTableFilter(String table) {
  281. return true;
  282. }
  283. }