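Since I am still a beginner with limited knowledge, this article simply records what I learned while implementing data permissions.
Overall Idea
I. MyBatis Interceptors
Reference:
From the official documentation:
MyBatis allows you to intercept calls at certain points within the execution of a mapped statement. By default, MyBatis allows plugins to intercept calls to the following methods: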
Executor(update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
ParameterHandler(getParameterObject, setParameters)
ResultSetHandler(handleResultSets, handleOutputParameters)
StatementHandler(prepare, parameterize, batch, update, query)
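Using a plugin is very simple given the power MyBatis provides: just implement the Interceptor interface and specify the signatures of the methods you want to intercept.
MyBatis calls this feature a plugin, but what it actually provides is exactly the interceptor we need. The methods and their parameters are explained below.
1. The Interceptor interface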
public interface Interceptor {
Object intercept(Invocation invocation) throws Throwable;
Object plugin(Object target);
void setProperties(Properties properties);
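}
Implementing the Interceptor interface means implementing its three methods: intercept, plugin and setProperties.
① intercept is where we act on the intercepted object, i.e. where our own logic goes.
② plugin receives each object that MyBatis creates for one of the interceptable types. It returns either that object unchanged or a proxy wrapping it. Normally you just return Plugin.wrap(target, this), which creates a proxy only when the target matches the @Signature declarations.
③ setProperties receives the properties configured for the plugin in the configuration file, so the plugin can read its settings there.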
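The three methods are called in this order: setProperties ---> plugin ---> intercept. setProperties runs once when the configuration is loaded. plugin runs whenever an interceptable object is created. intercept runs on every call to a matched method.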
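To see how the three methods fit together without pulling in MyBatis, here is a stdlib-only sketch. It is not MyBatis's code. SqlRunner, MiniInvocation and OrderRecordingInterceptor are hypothetical stand-ins. The JDK dynamic proxy here only approximates Plugin.wrap, which also uses JDK proxies but additionally filters calls by @Signature; this sketch skips that filtering. The sketch records the call order and appends a configured suffix to the result of the wrapped call.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

// Hypothetical stand-in for an object MyBatis would create, such as a StatementHandler
interface SqlRunner {
    String run(String sql);
}

// Hypothetical mini version of MyBatis's Invocation: target, method, args and proceed()
final class MiniInvocation {
    final Object target;
    final Method method;
    final Object[] args;

    MiniInvocation(Object target, Method method, Object[] args) {
        this.target = target;
        this.method = method;
        this.args = args;
    }

    // Calls the real method on the target, like Invocation.proceed()
    Object proceed() throws Exception {
        return method.invoke(target, args);
    }
}

// Hypothetical interceptor that records the order in which its three methods run
final class OrderRecordingInterceptor {
    final List<String> calls = new ArrayList<>();
    private String suffix = "";

    // Step 1: read configuration
    void setProperties(Properties properties) {
        calls.add("setProperties");
        suffix = properties.getProperty("suffix", "");
    }

    // Step 2: wrap the target in a JDK dynamic proxy that routes every call to intercept()
    SqlRunner plugin(SqlRunner target) {
        calls.add("plugin");
        InvocationHandler handler = (proxy, method, args) -> intercept(new MiniInvocation(target, method, args));
        return (SqlRunner) Proxy.newProxyInstance(
                SqlRunner.class.getClassLoader(), new Class<?>[] { SqlRunner.class }, handler);
    }

    // Step 3: runs on each call; our logic goes around proceed()
    Object intercept(MiniInvocation invocation) throws Exception {
        calls.add("intercept");
        Object result = invocation.proceed();
        return result + suffix;
    }
}
```

For example, after `setProperties` with `suffix = " /* intercepted */"`, `interceptor.plugin(sql -> "ran: " + sql).run("SELECT 1")` returns `"ran: SELECT 1 /* intercepted */"`. At that point `calls` holds setProperties, plugin, intercept, in that order.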
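2. Fields of the Invocation object passed to intercept
private final Object target; // the intercepted target object (may itself be a proxy when several plugins are stacked)
private final Method method; // the specific method being intercepted
private final Object[] args; // the method's arguments
A minimal interceptor skeleton:
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })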
public class MyInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
// your logic goes here
return invocation.proceed();
}
@Override
public Object plugin(Object target) {
// create the proxy object
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}
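}
Explanation:
@Intercepts declares the class as an interceptor, and @Signature declares what it intercepts.
MyBatis obtains the JDBC Statement inside the StatementHandler. What we need to get hold of is that Statement, and StatementHandler's prepare method is the one that returns it. So StatementHandler is the interface to intercept.
In @Signature:
- method is the method to intercept (prepare).
- type is the interface that declares it (StatementHandler).
- args lists the parameter types of prepare.

The source below is from an older MyBatis version. Since MyBatis 3.4.0, the method is declared as prepare(Connection connection, Integer transactionTimeout). On those versions the signature must be args = { Connection.class, Integer.class }; otherwise MyBatis cannot find the method when it wraps the target.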
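Plugin.wrap only routes a call to intercept when the called method matches one of the @Signature entries. It resolves each entry from the declared type, method name and parameter types, and fails if no such method exists. The stdlib-only sketch below imitates that lookup with Class.getMethod. MiniStatementHandler and SignatureMatcher are hypothetical names, not MyBatis classes.

```java
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Set;

// Hypothetical stand-in for StatementHandler with two of its methods
interface MiniStatementHandler {
    String prepare(String connection);
    void parameterize(Object statement);
}

// Keeps the set of methods declared via (type, method, args), like @Signature
final class SignatureMatcher {
    private final Set<Method> intercepted = new HashSet<>();

    // Resolves one signature; a wrong name or wrong parameter types fails immediately
    static Method lookup(Class<?> type, String methodName, Class<?>... args) {
        try {
            return type.getMethod(methodName, args);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException("Could not find method " + methodName + " on " + type.getName(), e);
        }
    }

    // Registers one (type, method, args) triple
    SignatureMatcher add(Class<?> type, String methodName, Class<?>... args) {
        intercepted.add(lookup(type, methodName, args));
        return this;
    }

    // True only if a call to this method should be routed to intercept()
    boolean matches(Method method) {
        return intercepted.contains(method);
    }
}
```

With only `prepare(String)` registered, calls to `parameterize` pass straight through to the target. Registering `prepare` with an extra `Integer` parameter fails, just as a stale `args = { Connection.class }` does on MyBatis versions where prepare takes two parameters.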
Source walkthrough:
The StatementHandler source:
public interface StatementHandler {
Statement prepare(Connection connection)
throws SQLException;
void parameterize(Statement statement)
throws SQLException;
void batch(Statement statement)
throws SQLException;
int update(Statement statement)
throws SQLException;
<E> List<E> query(Statement statement, ResultHandler resultHandler)
throws SQLException;
BoundSql getBoundSql();
ParameterHandler getParameterHandler();
}
The prepare method in this interface is the one we need to intercept. Its actual implementation is in BaseStatementHandler:
@Override
public Statement prepare(Connection connection) throws SQLException {
ErrorContext.instance().sql(boundSql.getSql());
Statement statement = null;
try {
statement = instantiateStatement(connection); // <----- this is the method we care about
setStatementTimeout(statement);
setFetchSize(statement);
return statement;
} catch (SQLException e) {
closeStatement(statement);
throw e;
} catch (Exception e) {
closeStatement(statement);
throw new ExecutorException("Error preparing statement. Cause: " + e, e);
}
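}
prepare obtains the Statement from instantiateStatement, which BaseStatementHandler declares as an abstract method:
protected abstract Statement instantiateStatement(Connection connection) throws SQLException;
Its concrete implementation depends on the statement type. Since our SQL is precompiled (a PreparedStatement), the relevant implementation is the one in PreparedStatementHandler: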
@Override
protected Statement instantiateStatement(Connection connection) throws SQLException {
String sql = boundSql.getSql(); // <---- this is our SQL statement
if (mappedStatement.getKeyGenerator() instanceof Jdbc3KeyGenerator) {
String[] keyColumnNames = mappedStatement.getKeyColumns();
if (keyColumnNames == null) {
return connection.prepareStatement(sql, PreparedStatement.RETURN_GENERATED_KEYS);
} else {
return connection.prepareStatement(sql, keyColumnNames);
}
} else if (mappedStatement.getResultSetType() != null) {
return connection.prepareStatement(sql, mappedStatement.getResultSetType().getValue(), ResultSet.CONCUR_READ_ONLY);
} else {
return connection.prepareStatement(sql);
}
}
With the SQL's execution path sorted out, we can now work on the intercepted StatementHandler:
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler handler = (StatementHandler)invocation.getTarget();
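// The target is a RoutingStatementHandler. Its delegate (a BaseStatementHandler subclass) holds mappedStatement,
// which contains the mapper method id we need. It is a protected field, so read it reflectively through MetaObject
MetaObject statementHandler = SystemMetaObject.forObject(handler);
MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
// get the SQL
BoundSql boundSql = handler.getBoundSql();
String sql = boundSql.getSql();
// get the mapper method id
String id = mappedStatement.getId();
if ("ID of the mapper method to enhance".equals(id)) {
// SQL enhancement goes here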
}
return invocation.proceed();
}
Once this is done, don't forget to register the interceptor in the MyBatis configuration file:
<configuration>
<plugins>
<plugin interceptor ="com.test.interceptor.MyInterceptor"/>
</plugins>
</configuration>
That completes writing and configuring the MyBatis interceptor. The next step is learning SQL parsing with JSqlParser.
II. JSqlParser
1. Add the jsqlparser dependency to the project
<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>1.0</version>
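</dependency>
2. Parse the SQL
First determine the type of the SQL statement (SELECT, UPDATE, INSERT, DELETE, ...).
Then convert the SQL into the corresponding object according to that type.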
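In the interceptor, the statement type comes from MappedStatement.getSqlCommandType(). That method returns the SqlCommandType enum, not a String. A check like "SELECT".equals(sqlCommandType) therefore compiles but is always false. The sketch below uses a hypothetical stand-in enum, CommandType, to show the difference. With MyBatis you would compare against SqlCommandType.SELECT in the same way.

```java
// Hypothetical stand-in for org.apache.ibatis.mapping.SqlCommandType
enum CommandType { UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH }

final class CommandTypeCheck {
    // Pitfall: String.equals(Object) accepts the enum, but a String never equals an enum constant
    static boolean isSelectWrong(CommandType type) {
        return "SELECT".equals(type);
    }

    // Correct: enum constants are singletons, so compare them by identity
    static boolean isSelect(CommandType type) {
        return type == CommandType.SELECT;
    }
}
```

`type.name().equals("SELECT")` would also work. `==` is the idiomatic way to compare enum constants, and it is null-safe on the variable side.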
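CCJSqlParserManager parserManager = new CCJSqlParserManager();
// sqlCommandType is the SqlCommandType enum, so compare it with ==, not with "SELECT".equals(...)
if (SqlCommandType.SELECT == sqlCommandType) {
    Select select = (Select) parserManager.parse(new StringReader(sql));
}
3. Visit the visitor implementations (SelectVisitorImpl is my own implementation of SelectVisitor)
The overall idea is to split the SQL statement into many small parts and then visit each part with the corresponding visitor implementation.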
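The visitor idea can be seen on a toy expression tree. Expr, ExprVisitor, Col, Val, Eq, And and ColumnCollector are hypothetical classes, not JSqlParser's. Each node's accept calls back the visit overload for its own type (double dispatch). The visitor then decides how to descend, in the same way that visitBinaryExpression in this article's ExpressionVisitorImpl visits both operands.

```java
import java.util.ArrayList;
import java.util.List;

// Toy expression tree: every node hands itself to the visitor
interface Expr {
    void accept(ExprVisitor visitor);
}

// One visit overload per node type, like JSqlParser's ExpressionVisitor
interface ExprVisitor {
    void visit(Col col);
    void visit(Val val);
    void visit(Eq eq);
    void visit(And and);
}

final class Col implements Expr {
    final String name;
    Col(String name) { this.name = name; }
    public void accept(ExprVisitor v) { v.visit(this); }
}

final class Val implements Expr {
    final String text;
    Val(String text) { this.text = text; }
    public void accept(ExprVisitor v) { v.visit(this); }
}

final class Eq implements Expr {
    final Expr left;
    final Expr right;
    Eq(Expr left, Expr right) { this.left = left; this.right = right; }
    public void accept(ExprVisitor v) { v.visit(this); }
}

final class And implements Expr {
    final Expr left;
    final Expr right;
    And(Expr left, Expr right) { this.left = left; this.right = right; }
    public void accept(ExprVisitor v) { v.visit(this); }
}

// A visitor that collects column names, descending into both sides of binary nodes
final class ColumnCollector implements ExprVisitor {
    final List<String> columns = new ArrayList<>();
    public void visit(Col col) { columns.add(col.name); }
    public void visit(Val val) { }  // literals contain no columns
    public void visit(Eq eq) { eq.left.accept(this); eq.right.accept(this); }
    public void visit(And and) { and.left.accept(this); and.right.accept(this); }
}
```

Visiting the tree for `a = 1 AND b = 'x'` with a ColumnCollector leaves `columns` as `[a, b]`. Adding a new operation means writing a new visitor; the node classes stay untouched.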
Applied to the parsed statement, this is a single call:
select.getSelectBody().accept(new SelectVisitorImpl());
SelectVisitorImpl.java:
package com.test.sqlparser.visitor;

import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitor;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.WithItem;

public class SelectVisitorImpl implements SelectVisitor {
    // The main job is to implement the lower-level visitors and add conditions while traversing.

    // A plain SELECT, i.e. one that can contain every clause
    @Override
    public void visit(PlainSelect plainSelect) {
        // visit the select items
        if (plainSelect.getSelectItems() != null) {
            for (SelectItem item : plainSelect.getSelectItems()) {
                item.accept(new SelectItemVisitorImpl());
            }
        }
        // visit FROM: this builds the permission condition for the main table
        // (FROM can be absent, e.g. "SELECT 1")
        FromItemVisitorImpl fromItemVisitorImpl = new FromItemVisitorImpl();
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem != null) {
            fromItem.accept(fromItemVisitorImpl);
        }
        // visit WHERE
        if (plainSelect.getWhere() != null) {
            plainSelect.getWhere().accept(new ExpressionVisitorImpl());
        }
        // merge the permission condition: (condition) AND (original WHERE)
        plainSelect.setWhere(and(fromItemVisitorImpl.getEnhancedCondition(), plainSelect.getWhere()));
        // visit JOINs and apply each joined table's condition too:
        // in its ON clause if it has one, otherwise in WHERE (e.g. "FROM a, b")
        if (plainSelect.getJoins() != null) {
            for (Join join : plainSelect.getJoins()) {
                FromItemVisitorImpl joinVisitor = new FromItemVisitorImpl();
                join.getRightItem().accept(joinVisitor);
                Expression joinCondition = joinVisitor.getEnhancedCondition();
                if (join.getOnExpression() != null) {
                    join.setOnExpression(and(joinCondition, join.getOnExpression()));
                } else {
                    plainSelect.setWhere(and(joinCondition, plainSelect.getWhere()));
                }
            }
        }
        // visit ORDER BY
        if (plainSelect.getOrderByElements() != null) {
            for (OrderByElement orderByElement : plainSelect.getOrderByElements()) {
                orderByElement.getExpression().accept(new ExpressionVisitorImpl());
            }
        }
        // visit HAVING
        if (plainSelect.getHaving() != null) {
            plainSelect.getHaving().accept(new ExpressionVisitorImpl());
        }
    }

    // (condition) AND (original); either side may be null
    private static Expression and(Expression condition, Expression original) {
        if (condition == null) {
            return original;
        }
        if (original == null) {
            return condition;
        }
        return new AndExpression(new Parenthesis(condition), new Parenthesis(original));
    }

    // set operations (UNION, INTERSECT, ...): visit each SELECT
    @Override
    public void visit(SetOperationList setOpList) {
        for (SelectBody plainSelect : setOpList.getSelects()) {
            plainSelect.accept(new SelectVisitorImpl());
        }
    }

    // WITH item (common table expression)
    @Override
    public void visit(WithItem withItem) {
        withItem.getSelectBody().accept(new SelectVisitorImpl());
    }
}

SelectItemVisitorImpl.java:

package com.test.sqlparser.visitor;

import net.sf.jsqlparser.statement.select.AllColumns;
import net.sf.jsqlparser.statement.select.AllTableColumns;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitor;

public class SelectItemVisitorImpl implements SelectItemVisitor {
    // SELECT *
    @Override
    public void visit(AllColumns allColumns) {
    }

    // SELECT t.*
    @Override
    public void visit(AllTableColumns allTableColumns) {
    }

    // expressions in the select list may contain subqueries
    @Override
    public void visit(SelectExpressionItem selectExpressionItem) {
        selectExpressionItem.getExpression().accept(new ExpressionVisitorImpl());
    }
}

ExpressionVisitorImpl.java:

package com.test.sqlparser.visitor;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.sf.jsqlparser.expression.AllComparisonExpression;
import net.sf.jsqlparser.expression.AnalyticExpression;
import net.sf.jsqlparser.expression.AnyComparisonExpression;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.CaseExpression;
import net.sf.jsqlparser.expression.CastExpression;
import net.sf.jsqlparser.expression.DateTimeLiteralExpression;
import net.sf.jsqlparser.expression.DateValue;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitor;
import net.sf.jsqlparser.expression.ExtractExpression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.IntervalExpression;
import net.sf.jsqlparser.expression.JdbcNamedParameter;
import net.sf.jsqlparser.expression.JdbcParameter;
import net.sf.jsqlparser.expression.JsonExpression;
import net.sf.jsqlparser.expression.KeepExpression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.MySQLGroupConcat;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.NumericBind;
import net.sf.jsqlparser.expression.OracleHierarchicalExpression;
import net.sf.jsqlparser.expression.OracleHint;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.RowConstructor;
import net.sf.jsqlparser.expression.SignedExpression;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.TimeKeyExpression;
import net.sf.jsqlparser.expression.TimeValue;
import net.sf.jsqlparser.expression.TimestampValue;
import net.sf.jsqlparser.expression.UserVariable;
import net.sf.jsqlparser.expression.WhenClause;
import net.sf.jsqlparser.expression.WithinGroupExpression;
import net.sf.jsqlparser.expression.operators.arithmetic.Addition;
import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseAnd;
import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseOr;
import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseXor;
import net.sf.jsqlparser.expression.operators.arithmetic.Concat;
import net.sf.jsqlparser.expression.operators.arithmetic.Division;
import net.sf.jsqlparser.expression.operators.arithmetic.Modulo;
import net.sf.jsqlparser.expression.operators.arithmetic.Multiplication;
import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.Matches;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.RegExpMatchOperator;
import net.sf.jsqlparser.expression.operators.relational.RegExpMySQLOperator;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.select.WithItem;

public class ExpressionVisitorImpl implements ExpressionVisitor {
    private static final Logger logger = LoggerFactory.getLogger(ExpressionVisitorImpl.class);

    // signed expression such as -x: visit the inner expression
    // (calling signedExpression.accept(...) here would recurse forever)
    @Override
    public void visit(SignedExpression signedExpression) {
        signedExpression.getExpression().accept(this);
    }

    // JDBC parameter "?"
    @Override
    public void visit(JdbcParameter jdbcParameter) {
    }

    // named JDBC parameter ":name"
    @Override
    public void visit(JdbcNamedParameter jdbcNamedParameter) {
    }

    // parenthesized expression
    @Override
    public void visit(Parenthesis parenthesis) {
        parenthesis.getExpression().accept(new ExpressionVisitorImpl());
    }

    // BETWEEN
    @Override
    public void visit(Between between) {
        between.getLeftExpression().accept(new ExpressionVisitorImpl());
        between.getBetweenExpressionStart().accept(new ExpressionVisitorImpl());
        between.getBetweenExpressionEnd().accept(new ExpressionVisitorImpl());
    }

    // IN; the item lists (which may be subqueries) are visited by ItemsListVisitorImpl, not shown in this article
    @Override
    public void visit(InExpression inExpression) {
        if (inExpression.getLeftExpression() != null) {
            inExpression.getLeftExpression().accept(new ExpressionVisitorImpl());
        } else if (inExpression.getLeftItemsList() != null) {
            inExpression.getLeftItemsList().accept(new ItemsListVisitorImpl());
        }
        inExpression.getRightItemsList().accept(new ItemsListVisitorImpl());
    }

    // subquery: hand it back to the SELECT visitor
    @Override
    public void visit(SubSelect subSelect) {
        if (subSelect.getWithItemsList() != null) {
            for (WithItem withItem : subSelect.getWithItemsList()) {
                withItem.accept(new SelectVisitorImpl());
            }
        }
        subSelect.getSelectBody().accept(new SelectVisitorImpl());
    }

    // EXISTS (subquery)
    @Override
    public void visit(ExistsExpression existsExpression) {
        existsExpression.getRightExpression().accept(new ExpressionVisitorImpl());
    }

    // x op ALL (subquery)
    @Override
    public void visit(AllComparisonExpression allComparisonExpression) {
        allComparisonExpression.getSubSelect().getSelectBody().accept(new SelectVisitorImpl());
    }

    // x op ANY (subquery)
    @Override
    public void visit(AnyComparisonExpression anyComparisonExpression) {
        anyComparisonExpression.getSubSelect().getSelectBody().accept(new SelectVisitorImpl());
    }

    // Oracle START WITH ... CONNECT BY
    @Override
    public void visit(OracleHierarchicalExpression oexpr) {
        if (oexpr.getStartExpression() != null) {
            oexpr.getStartExpression().accept(this);
        }
        if (oexpr.getConnectExpression() != null) {
            oexpr.getConnectExpression().accept(this);
        }
    }

    // row constructor, e.g. (a, b)
    @Override
    public void visit(RowConstructor rowConstructor) {
        for (Expression expr : rowConstructor.getExprList().getExpressions()) {
            expr.accept(this);
        }
    }

    // CAST(x AS type)
    @Override
    public void visit(CastExpression cast) {
        cast.getLeftExpression().accept(new ExpressionVisitorImpl());
    }

    // arithmetic operators
    @Override public void visit(Addition addition) { visitBinaryExpression(addition); }
    @Override public void visit(Division division) { visitBinaryExpression(division); }
    @Override public void visit(Multiplication multiplication) { visitBinaryExpression(multiplication); }
    @Override public void visit(Subtraction subtraction) { visitBinaryExpression(subtraction); }
    @Override public void visit(Modulo modulo) { visitBinaryExpression(modulo); }
    @Override public void visit(Concat concat) { visitBinaryExpression(concat); }
    // bitwise operators
    @Override public void visit(BitwiseAnd bitwiseAnd) { visitBinaryExpression(bitwiseAnd); }
    @Override public void visit(BitwiseOr bitwiseOr) { visitBinaryExpression(bitwiseOr); }
    @Override public void visit(BitwiseXor bitwiseXor) { visitBinaryExpression(bitwiseXor); }
    // AND / OR
    @Override public void visit(AndExpression andExpression) { visitBinaryExpression(andExpression); }
    @Override public void visit(OrExpression orExpression) { visitBinaryExpression(orExpression); }
    // comparisons
    @Override public void visit(EqualsTo equalsTo) { visitBinaryExpression(equalsTo); }
    @Override public void visit(NotEqualsTo notEqualsTo) { visitBinaryExpression(notEqualsTo); }
    @Override public void visit(GreaterThan greaterThan) { visitBinaryExpression(greaterThan); }
    @Override public void visit(GreaterThanEquals greaterThanEquals) { visitBinaryExpression(greaterThanEquals); }
    @Override public void visit(MinorThan minorThan) { visitBinaryExpression(minorThan); }
    @Override public void visit(MinorThanEquals minorThanEquals) { visitBinaryExpression(minorThanEquals); }
    // LIKE, full-text MATCHES (@@) and regular-expression operators
    @Override public void visit(LikeExpression likeExpression) { visitBinaryExpression(likeExpression); }
    @Override public void visit(Matches matches) { visitBinaryExpression(matches); }
    @Override public void visit(RegExpMatchOperator rexpr) { visitBinaryExpression(rexpr); }
    @Override public void visit(RegExpMySQLOperator regExpMySQLOperator) { visitBinaryExpression(regExpMySQLOperator); }

    // binary expression: visit the left and right operands
    public void visitBinaryExpression(BinaryExpression binaryExpression) {
        binaryExpression.getLeftExpression().accept(new ExpressionVisitorImpl());
        binaryExpression.getRightExpression().accept(new ExpressionVisitorImpl());
    }

    // ---- Not traversed below. Subqueries nested inside these expressions
    // (e.g. inside CASE or function arguments) will not receive permission conditions. ----
    @Override public void visit(AnalyticExpression aexpr) { }           // window function ... OVER (...)
    @Override public void visit(WithinGroupExpression wgexpr) { }       // ... WITHIN GROUP (ORDER BY ...)
    @Override public void visit(ExtractExpression eexpr) { }            // EXTRACT(field FROM x)
    @Override public void visit(IntervalExpression iexpr) { }           // INTERVAL literal
    @Override public void visit(JsonExpression jsonExpr) { }            // JSON operators such as ->
    @Override public void visit(OracleHint hint) { }                    // /*+ hint */
    @Override public void visit(TimeKeyExpression timeKeyExpression) { } // CURRENT_TIMESTAMP etc.
    @Override public void visit(CaseExpression caseExpression) { }      // CASE ... END
    @Override public void visit(WhenClause whenClause) { }              // WHEN ... THEN ...
    @Override public void visit(UserVariable var) { }                   // MySQL @variable
    @Override public void visit(NumericBind bind) { }                   // Oracle :1 bind
    @Override public void visit(KeepExpression aexpr) { }               // Oracle KEEP (DENSE_RANK ...)
    @Override public void visit(MySQLGroupConcat groupConcat) { }       // GROUP_CONCAT(...)
    @Override public void visit(Column tableColumn) { }                 // column
    @Override public void visit(DoubleValue doubleValue) { }            // double literal
    @Override public void visit(LongValue longValue) { }                // long literal
    @Override public void visit(HexValue hexValue) { }                  // hexadecimal literal
    @Override public void visit(DateValue dateValue) { }                // date literal
    @Override public void visit(TimeValue timeValue) { }                // time literal
    @Override public void visit(TimestampValue timestampValue) { }      // timestamp literal
    @Override public void visit(NullValue nullValue) { }                // NULL
    @Override public void visit(Function function) { }                  // function call
    @Override public void visit(StringValue stringValue) { }            // string literal
    @Override public void visit(IsNullExpression isNullExpression) { }  // IS [NOT] NULL
    @Override public void visit(DateTimeLiteralExpression literal) { }  // e.g. DATE '2020-01-01'
}

FromItemVisitorImpl.java:

package com.test.sqlparser.visitor;

import java.util.ArrayList;
import java.util.List;

import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItemVisitor;
import net.sf.jsqlparser.statement.select.LateralSubSelect;
import net.sf.jsqlparser.statement.select.SubJoin;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.select.TableFunction;
import net.sf.jsqlparser.statement.select.ValuesList;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.test.entity.TableCondition;
import com.test.security.UserUtils;

public class FromItemVisitorImpl implements FromItemVisitor {
    private static final Logger logger = LoggerFactory.getLogger(FromItemVisitorImpl.class);

    // the permission condition built for this FROM item
    private Expression enhancedCondition;

    // FROM table <---- the important one: check the user's permissions on this table
    @Override
    public void visit(Table tableName) {
        // is this a table we need to restrict?
        if (!isActionTable(tableName.getFullyQualifiedName())) {
            return;
        }
        // look up the current user's restrictions for this table
        List<TableCondition> conditions = UserUtils.getTableCondition(tableName.getFullyQualifiedName().toUpperCase());
        if (conditions == null) {
            return;
        }
        // qualify columns with the alias if there is one, otherwise with the table name
        Table columnTable = new Table(tableName.getAlias() != null ? tableName.getAlias().getName() : tableName.getFullyQualifiedName());
        for (TableCondition tableCondition : conditions) {
            String op = tableCondition.getOperator();
            Column column = new Column(columnTable, tableCondition.getFieldName());
            Expression[] expressions;
            if ("between".equalsIgnoreCase(op) || "not between".equalsIgnoreCase(op)) {
                // BETWEEN needs two values but TableCondition carries only one; fail rather than drop the restriction
                throw new UnsupportedOperationException("BETWEEN conditions are not supported: " + tableCondition.getFieldName());
            } else if ("is null".equalsIgnoreCase(op) || "is not null".equalsIgnoreCase(op)) {
                // IS [NOT] NULL needs only the column
                expressions = new Expression[] { column };
            } else if ("1".equals(tableCondition.getFieldName())) {
                // constant condition such as 1 = 1
                expressions = new Expression[] { new LongValue(tableCondition.getFieldName()), new LongValue(tableCondition.getFieldValue()) };
            } else {
                // the common case: column op 'value'
                // StringValue's String constructor treats its argument as a quoted SQL literal,
                // so add the quotes (doubling any inner ones)
                String literal = "'" + tableCondition.getFieldValue().replace("'", "''") + "'";
                expressions = new Expression[] { column, new StringValue(literal) };
            }
            // build the expression for this operator
            Expression condition = getOperator(op, expressions);
            if (condition == null) {
                // fail closed: an unsupported operator must not silently widen access
                throw new IllegalArgumentException("Unsupported operator: " + op);
            }
            enhancedCondition = (enhancedCondition == null) ? condition : new AndExpression(enhancedCondition, condition);
        }
    }

    // FROM (subquery): hand it back to the SELECT visitor
    @Override
    public void visit(SubSelect subSelect) {
        subSelect.getSelectBody().accept(new SelectVisitorImpl());
    }

    // FROM subjoin, e.g. (a JOIN b); conditions built here are not applied
    @Override
    public void visit(SubJoin subjoin) {
        subjoin.getLeft().accept(new FromItemVisitorImpl());
        subjoin.getJoin().getRightItem().accept(new FromItemVisitorImpl());
    }

    // FROM LATERAL (subquery)
    @Override
    public void visit(LateralSubSelect lateralSubSelect) {
        lateralSubSelect.getSubSelect().getSelectBody().accept(new SelectVisitorImpl());
    }

    // FROM VALUES list
    @Override
    public void visit(ValuesList valuesList) {
    }

    // FROM table function
    @Override
    public void visit(TableFunction tableFunction) {
    }

    // Converts an operator string into the matching JSqlParser expression; returns null if unsupported
    private Expression getOperator(String op, Expression[] exp) {
        if ("=".equals(op)) {
            return binary(new EqualsTo(), exp);
        } else if (">".equals(op)) {
            return binary(new GreaterThan(), exp);
        } else if (">=".equals(op)) {
            return binary(new GreaterThanEquals(), exp);
        } else if ("<".equals(op)) {
            return binary(new MinorThan(), exp);
        } else if ("<=".equals(op)) {
            return binary(new MinorThanEquals(), exp);
        } else if ("<>".equals(op)) {
            return binary(new NotEqualsTo(), exp);
        } else if ("is null".equalsIgnoreCase(op) || "is not null".equalsIgnoreCase(op)) {
            IsNullExpression isNull = new IsNullExpression();
            isNull.setNot("is not null".equalsIgnoreCase(op));
            isNull.setLeftExpression(exp[0]);
            return isNull;
        } else if ("like".equalsIgnoreCase(op) || "not like".equalsIgnoreCase(op)) {
            LikeExpression like = new LikeExpression();
            like.setNot("not like".equalsIgnoreCase(op));
            return binary(like, exp);
        } else if ("between".equalsIgnoreCase(op) || "not between".equalsIgnoreCase(op)) {
            Between between = new Between();
            between.setNot("not between".equalsIgnoreCase(op));
            between.setLeftExpression(exp[0]);
            between.setBetweenExpressionStart(exp[1]);
            between.setBetweenExpressionEnd(exp[2]);
            return between;
        }
        // no expression for this operator
        return null;
    }

    // sets both operands of a binary expression
    private static Expression binary(BinaryExpression expression, Expression[] exp) {
        expression.setLeftExpression(exp[0]);
        expression.setRightExpression(exp[1]);
        return expression;
    }

    public Expression getEnhancedCondition() {
        return enhancedCondition;
    }

    // Returns whether the given table should be restricted (every table is, by default)
    public boolean isActionTable(String tableName) {
        // names of tables that need no restriction (left empty here; fill in as needed)
        List<String> tableNames = new ArrayList<String>();
        // the SQL may use lower-case names, so compare in upper case, as in visit(Table) above
        return !tableNames.contains(tableName.toUpperCase());
    }
}

Complete interceptor code:
package com.test.interceptor;

import java.io.StringReader;
import java.sql.Connection;
import java.util.Properties;

import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.statement.select.Select;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import com.test.sqlparser.visitor.SelectVisitorImpl;

// On MyBatis 3.4.0+ prepare is prepare(Connection, Integer): use args = { Connection.class, Integer.class }
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })
public class MyInterceptor implements Interceptor {

    private final CCJSqlParserManager parserManager = new CCJSqlParserManager();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        // mappedStatement is a protected field, so read it reflectively through MetaObject
        MetaObject statementHandler = SystemMetaObject.forObject(handler);
        // mappedStatement holds the id of the mapper method being executed
        MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
        // get the SQL
        BoundSql boundSql = handler.getBoundSql();
        String sql = boundSql.getSql();
        // get the mapper method id
        String id = mappedStatement.getId();
        // get the statement type (an enum, not a String)
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        if ("ID of the mapper method to enhance".equals(id)) {
            // SQL enhancement; compare the enum with == ("SELECT".equals(sqlCommandType) is always false)
            if (SqlCommandType.SELECT == sqlCommandType) {
                // for a SELECT, parse the SQL into a Select object
                Select select = (Select) parserManager.parse(new StringReader(sql));
                // walk it with the visitors, which add the permission conditions
                select.getSelectBody().accept(new SelectVisitorImpl());
                // write the enhanced SQL back
                statementHandler.setValue("delegate.boundSql.sql", select.toString());
            }
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}
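SelectVisitorImpl wraps both the permission condition and the original WHERE in Parenthesis before joining them with AND. The reason is precedence: AND binds more tightly than OR, so simply appending "AND ..." to a WHERE that contains OR would restrict only one branch. The string-based sketch below contrasts the naive merge with the parenthesized one. ConditionMerger is a hypothetical stdlib-only helper; JSqlParser itself builds expression objects rather than strings.

```java
final class ConditionMerger {
    // Naive merge: relies on operator precedence, which goes wrong when the original WHERE contains OR
    static String mergeNaive(String permission, String originalWhere) {
        return permission + " AND " + originalWhere;
    }

    // Parenthesized merge: wrap both sides, then AND them; either side may be absent (null)
    static String merge(String permission, String originalWhere) {
        if (permission == null) {
            return originalWhere;
        }
        if (originalWhere == null) {
            return permission;
        }
        return "(" + permission + ") AND (" + originalWhere + ")";
    }
}
```

Take permission `t.dept_id = '10'` and original WHERE `t.status = 1 OR t.owner = 'me'`, and consider a row from another department whose owner is 'me' and whose status is not 1. The naive string evaluates as (false AND false) OR true = true, so the row leaks. The parenthesized form gives false AND (false OR true) = false.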