package com.alibaba.kaola.ad.searchad.dao.interceptors;
import java.sql.Connection;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import com.alibaba.kaola.ad.searchad.common.annotation.EnvCheck;
import com.alibaba.kaola.ad.searchad.common.utils.DateUtil;
import com.alibaba.kaola.ad.searchad.common.utils.EnvManager;
import com.alibaba.kaola.ad.searchad.common.utils.LoggerUtil;
import com.alibaba.kaola.ad.searchad.common.utils.ReflectHelper;
import com.alibaba.kaola.ad.searchad.common.utils.SpringContextUtil;
import com.alibaba.kaola.ad.searchad.common.utils.SqlUtils;
import com.alibaba.kaola.ad.searchad.dao.base.IbatisSwitch;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.core.annotation.AnnotationUtils;
/**
* @author zhenyuan.he
*/
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class IbatisPrepareInterceptor implements Interceptor {
static final String REGEX = "\\?";
static final String FIELD_DELEGATE = "delegate";
static final String FIELD_MAPPED_STATEMENT = "mappedStatement";
@Override
public Object intercept(Invocation invocation) {
BoundSql boundSql = null;
Configuration configuration = null;
SqlCommandType sqlCommandType = null;
long startTime = System.currentTimeMillis();
try {
final StatementHandler statementHandler = (StatementHandler)invocation.getTarget();
boundSql = statementHandler.getBoundSql();
final StatementHandler handler = (StatementHandler)ReflectHelper.getFieldValue(statementHandler,
FIELD_DELEGATE);
final MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(handler,
FIELD_MAPPED_STATEMENT);
// 获取节点的配置
configuration = mappedStatement.getConfiguration();
sqlCommandType = mappedStatement.getSqlCommandType();
//1、如果不是查询语句,放过不拦截
if (!SqlCommandType.SELECT.equals(sqlCommandType)) {
return invocation.proceed();
}
//4、如果没有设置需要做空间隔离,放过不拦截
final String className = mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf("."));
final Class<?> targetMapper = Class.forName(className);
final EnvCheck envCheck = AnnotationUtils.findAnnotation(targetMapper, EnvCheck.class);
if (envCheck == null) {
return invocation.proceed();
}
final String env = SpringContextUtil.getBean(EnvManager.class).getCurrentEnv();
// 利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
ReflectHelper.setFieldValue(boundSql, "sql", replaceSql(boundSql.getSql(), env));
return invocation.proceed();
} catch (Throwable e) {
LoggerUtil.error(e, "IbatisPrepareInterceptor error");
throw new RuntimeException(e);
} finally {
try {
if (IbatisSwitch.SQL_PRINT_OPEN && sqlCommandType != null && boundSql != null) {
final String sql = showSql(configuration, boundSql);
LoggerUtil.warn(LoggerUtil.sqlLogger, "sql=", sql, ",tableName=",
String.join(",", getTableNames(boundSql.getSql())), ",sqlType=", sqlCommandType.name(),
",cost=", System.currentTimeMillis() - startTime, "ms");
}
} catch (Throwable e) {
LoggerUtil.error(e, "sql error");
}
}
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}
/**
* 新的sql
*
* @param sql 原sql
* @param env 环境变量
* @return 新sql
*/
private static String replaceSql(String sql, String env) {
final String envCondition = "env='" + env + "'";
return SqlUtils.addWhereCondition(sql, envCondition);
}
/**
* 进行?的替换
*/
private String showSql(Configuration configuration, BoundSql boundSql) {
// 获取参数
Object parameterObject = boundSql.getParameterObject();
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
// sql语句中多个空格都用一个空格代替
String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
if (CollectionUtils.isEmpty(parameterMappings) || parameterObject == null) {
return sql;
}
// 获取类型处理器注册器,类型处理器的功能是进行java类型和数据库类型的转换
TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
// 如果根据parameterObject.getClass()可以找到对应的类型,则替换
if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
sql = sql.replaceFirst(REGEX,
Matcher.quoteReplacement(getParameterValue(parameterObject)));
} else {
// MetaObject主要是封装了originalObject对象,提供了get和set的方法用于获取和设置originalObject的属性值,
// 主要支持对JavaBean、Collection、Map三种类型对象的操作
MetaObject metaObject = configuration.newMetaObject(parameterObject);
for (ParameterMapping parameterMapping : parameterMappings) {
String propertyName = parameterMapping.getProperty();
Object obj;
if (metaObject.hasGetter(propertyName)) {
obj = metaObject.getValue(propertyName);
} else if (boundSql.hasAdditionalParameter(propertyName)) {
// 该分支是动态sql
obj = boundSql.getAdditionalParameter(propertyName);
} else {
obj = null;
}
final String value = Matcher.quoteReplacement(getParameterValue(obj));
sql = sql.replaceFirst(REGEX, value);
}
}
return sql;
}
private static String getParameterValue(Object obj) {
if (obj == null) {
return StringUtils.EMPTY;
}
String value;
if (obj instanceof String) {
value = "'" + obj.toString() + "'";
} else if (obj instanceof Date) {
value = "'" + DateUtil.formatDate((Date)obj, DateUtil.DATETIME19_PATTERN) + "'";
} else {
value = obj.toString();
}
return value;
}
/**
* 获取表名
*
* @param sql sql
* @return 获取sql所有的表名
*/
private Set<String> getTableNames(String sql) {
Statement statement;
try {
statement = CCJSqlParserUtil.parse(sql);
} catch (JSQLParserException e) {
throw new RuntimeException("解析sql语句错误!sql:" + sql, e);
}
TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
List<String> tableList = tablesNamesFinder.getTableList(statement);
Set<String> tableNames = new HashSet<>();
for (String tableName : tableList) {
//获取去掉“`”的表名
if (tableName.startsWith("`") && tableName.endsWith("`")) {
tableNames.add(tableName.substring(1, tableName.length() - 1));
} else {
tableNames.add(tableName);
}
}
return tableNames;
}
}
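// 版权声明:本文为qq_18871751原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。
// (Copyright notice: this is an original article by qq_18871751, licensed under CC BY-SA 4.0; any reproduction
// must include a link to the original source and this notice.)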