Browse Source

:sparkles: 自动判断数据库类型适配一些特殊情况

Fix https://github.com/baomidou/mybatis-plus/pull/811
Cat73 6 years ago
parent
commit
5ca013fd15

+ 41 - 35
mybatis-plus-generator/src/main/java/com/baomidou/mybatisplus/generator/config/DataSourceConfig.java

@@ -15,28 +15,17 @@
  */
 package com.baomidou.mybatisplus.generator.config;
 
-import java.sql.Connection;
-import java.sql.DriverManager;
-import java.sql.SQLException;
-
 import com.baomidou.mybatisplus.annotation.DbType;
 import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
-import com.baomidou.mybatisplus.generator.config.converts.DB2TypeConvert;
-import com.baomidou.mybatisplus.generator.config.converts.MySqlTypeConvert;
-import com.baomidou.mybatisplus.generator.config.converts.OracleTypeConvert;
-import com.baomidou.mybatisplus.generator.config.converts.PostgreSqlTypeConvert;
-import com.baomidou.mybatisplus.generator.config.converts.SqlServerTypeConvert;
-import com.baomidou.mybatisplus.generator.config.querys.DB2Query;
-import com.baomidou.mybatisplus.generator.config.querys.H2Query;
-import com.baomidou.mybatisplus.generator.config.querys.MariadbQuery;
-import com.baomidou.mybatisplus.generator.config.querys.MySqlQuery;
-import com.baomidou.mybatisplus.generator.config.querys.OracleQuery;
-import com.baomidou.mybatisplus.generator.config.querys.PostgreSqlQuery;
-import com.baomidou.mybatisplus.generator.config.querys.SqlServerQuery;
-
+import com.baomidou.mybatisplus.generator.config.converts.*;
+import com.baomidou.mybatisplus.generator.config.querys.*;
 import lombok.Data;
 import lombok.experimental.Accessors;
 
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+
 /**
  * 数据库配置
  *
@@ -116,26 +105,43 @@ public class DataSourceConfig {
      * @return 类型枚举值
      */
     public DbType getDbType() {
-        if (null == dbType) {
-            if (driverName.contains("mysql")) {
-                dbType = DbType.MYSQL;
-            } else if (driverName.contains("oracle")) {
-                dbType = DbType.ORACLE;
-            } else if (driverName.contains("postgresql")) {
-                dbType = DbType.POSTGRE_SQL;
-            } else if (driverName.contains("sqlserver")) {
-                dbType = DbType.SQL_SERVER;
-            } else if (driverName.contains("db2")) {
-                dbType = DbType.DB2;
-            } else if (driverName.contains("mariadb")) {
-                dbType = DbType.MARIADB;
-            } else if(driverName.contains("h2")){
-                dbType = DbType.H2;
-            }else {
-                throw ExceptionUtils.mpe("Unknown type of database!");
+        if (null == this.dbType) {
+            this.dbType = this.getDbType(this.driverName);
+            if (null == this.dbType) {
+                this.dbType = this.getDbType(this.url.toLowerCase());
+                if (null == this.dbType) {
+                    throw ExceptionUtils.mpe("Unknown type of database!");
+                }
             }
         }
-        return dbType;
+
+        return this.dbType;
+    }
+
+    /**
+     * 判断数据库类型
+     *
+     * @param str 用于寻找特征的字符串,可以是 driverName 或小写后的 url
+     * @return 类型枚举值,如果没找到,则返回 null
+     */
+    private DbType getDbType(String str) {
+        if (str.contains("mysql")) {
+            return DbType.MYSQL;
+        } else if (str.contains("oracle")) {
+            return DbType.ORACLE;
+        } else if (str.contains("postgresql")) {
+            return DbType.POSTGRE_SQL;
+        } else if (str.contains("sqlserver")) {
+            return DbType.SQL_SERVER;
+        } else if (str.contains("db2")) {
+            return DbType.DB2;
+        } else if (str.contains("mariadb")) {
+            return DbType.MARIADB;
+        } else if (str.contains("h2")) {
+            return DbType.H2;
+        } else {
+            return null;
+        }
     }
 
     public ITypeConvert getTypeConvert() {