mybatis-plus框架,通过java bean 生成 sql文件

mybatis-plus 通过java bean生成 sql
后端技术

背景

由于项目背景,导致数据库和代码的版本存在差异,大概有一百多表,集中表现在数据库缺表和缺字段,代码Entity采用驼峰形式,而数据库使用的下划线方式, 如果采用人工一个个对的话,费时费力,并且还有遗漏
项目采用的框架是spring boot+ mybaits plus方式,mybaits plus也是驼峰转下划线,并且自动做了映射,这么一来,就可以写脚本方式,生成diff sql了
以后也可以直接先写entity ,然后生成sql,对于java开发者更加方便

具体实现

Spring 容器工具类,用于获取容器的bean

@Component
public class SpringContextUtils implements ApplicationContextAware {
    public static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext applicationContext)
        throws BeansException {
    SpringContextUtils.applicationContext = applicationContext;
}

public static Object getBean(String name) {
    return applicationContext.getBean(name);
}

public static <T> T getBean(String name, Class<T> requiredType) {
    return applicationContext.getBean(name, requiredType);
}

public static boolean containsBean(String name) {
    return applicationContext.containsBean(name);
}

public static boolean isSingleton(String name) {
    return applicationContext.isSingleton(name);
}

public static Class<? extends Object> getType(String name) {
    return applicationContext.getType(name);
}

}

SQl生成工具类

public class GenerateSqlMapperUtil {


    /**
     * 生成建表语句
     *
     * @param c Entity的类名
     * @return
     * @throws IOException
     */
    public static String generateCreateTable(Class<?> c) throws IOException {
        TableInfo tableInfo = SqlHelper.table(c);

        StringBuilder sql = new StringBuilder();
        sql.append("create table ").append(tableInfo.getTableName()).append("( \r\n");
        for (TableFieldInfo f : tableInfo.getFieldList()) {
            sql.append(f.getColumn()).append(" ");
            sql.append(getJDBCType(f));
            //如果是主键的话
            if (tableInfo.getKeyColumn().equals(f.getColumn())) {
                // 默认第一个字段为ID主键
                sql.append(" PRIMARY KEY AUTO_INCREMENT");
            }
            sql.append(",\n");
        }
        sql.delete(sql.lastIndexOf(","), sql.length()).append("\n)ENGINE=INNODB DEFAULT CHARSET=UTF8MB4 AUTO_INCREMENT=1;\r\n");
        return sql.toString();

    }

    /**
     * 获取数据库类型
     *
     * @param
     * @return
     */
    public static String getJDBCType(TableFieldInfo f) {
        Class<?> type = f.getField().getType();
        if (type.equals(Integer.class)) {
            return "INTEGER";
        } else if (type.equals(Long.class)) {
            return "BIGINT";
        } else if (type.equals(java.math.BigDecimal.class)) {
            // 根据需要自行修改
            return "decimal(16,8)";
        } else if (type.equals(Date.class)) {
            return "datetime";
        } else if (type.equals(Short.class)) {
            return "smallint";
        } else if (type.equals(Byte.class)) {
            return "tinyint";
        } else {
            //兜底用VARCHAR,如有需要可以扩展下去
            return "VARCHAR(100)";
        }

    }

    /**
     * 根据 mapper 获取 entity
     * 比如 public interface MemberLevelMapper extends BaseMapper<MemberLevel> 的entity 是MemberLevel
     *
     * @param
     */
    public static Class<?> getActualType(Class<?> c) throws ClassNotFoundException {

        //只有一个继承,选择第一个
        ParameterizedType parameterizedType = (ParameterizedType) c.getGenericInterfaces()[0];
        //,只有一个范型,需要选择第一个
        Type actualTypeArgument = parameterizedType.getActualTypeArguments()[0];
        return Class.forName(actualTypeArgument.getTypeName());
    }


}

修复主流程的工具类

/**
 * @author cc
 */
@Log4j2
public class FixUtils {


    public static String fix() throws ClassNotFoundException, IOException {

        //扫描所有的spring 容器的所有的Mapper
        Map<String, Object> beansWithAnnotation = SpringContextUtils.applicationContext.getBeansWithAnnotation(Mapper.class);

        StringBuilder stringBuilder = new StringBuilder();

        for (Map.Entry<String, Object> stringObjectEntry : beansWithAnnotation.entrySet()) {
            Object value = stringObjectEntry.getValue();
            Class<?> mapperClassType = SpringContextUtils.applicationContext.getType(stringObjectEntry.getKey());
            try {
                //通过反射调用selectById方法,来测试缺表或者缺字段
                value.getClass().getMethod("selectById", java.io.Serializable.class).invoke(value, "1");

            } catch (Exception e) {
                String message = ((InvocationTargetException) e).getTargetException().getMessage();
                Class<?> entityClass = GenerateSqlMapperUtil.getActualType(mapperClassType);
                TableInfo tableInfo = SqlHelper.table(entityClass);
                List<TableFieldInfo> fieldList = tableInfo.getFieldList();

                //缺表的情况
                if (message.contains("Table") && message.contains("doesn't exist")) {

                    stringBuilder.append(GenerateSqlMapperUtil.generateCreateTable(entityClass)).append(" \r\n");

                    //缺字段的情况
                } else if (message.contains("column") && message.contains("field list")) {
                    String[] split = message.split("'");
                    //缺少哪一个字段
                    String column = split[1];
                    Optional<TableFieldInfo> first = fieldList.stream().filter(t -> t.getColumn().equals(column)).findFirst();
                    if (first.isPresent()) {
                        TableFieldInfo tableFieldInfo = first.get();
                        String columnType = GenerateSqlMapperUtil.getJDBCType(tableFieldInfo);
                        //此处可以读取java doc注释,生成sql注释
                        String str = "alter table %s add %s %s null;";
                        String format = String.format(str, tableInfo.getTableName(), column, columnType);
                        stringBuilder.append(format).append("\r\n");
                    } else {
                        log.error("未能获取缺少的字段名:{},{}", tableInfo.getTableName(), message);
                    }
                } else {
                    log.error("无法处理此情况:{},{}", tableInfo.getTableName(), message);

                }
            }

        }

        return stringBuilder.toString();


    }

    /**
     * 写入文件
     */
    public static void writeSql() throws IOException, ClassNotFoundException {
        PrintStream out = new PrintStream(new FileOutputStream("/tmp/and.sql"));
        out.println(fix());
    }


}

怎么使用

  1. 因为需要使用spring 容器, 可以在Controller里面触发 FixUtils.writeSql();,也可以job里面调用
  2. 生成的默认文件在/tmp/and.sql, 直接用mysql 客户端source and.sql;

后记

  1. 如果一个表缺多个字段的话,需要多次生成
  2. 生成的字段,类型并不是最合适的,没有默认值,没有注释,需要手动改下
  3. 代码并没有考虑太多的异常情况,如果有不懂的,可以联系我
bigcong