Skip to content

Commit

Permalink
修复没有表别名语法错误问题
Browse files Browse the repository at this point in the history
  • Loading branch information
tangllty committed Nov 16, 2023
1 parent f451a2a commit 6951ffe
Showing 1 changed file with 41 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.lang.Nullable;

import com.tang.commons.annotation.DictPermission;
import com.tang.commons.model.SysDictDataModel;
Expand All @@ -43,12 +44,11 @@
})
public class DictPermissionInterceptor implements Interceptor {

private static final String LIMIT = "LIMIT";

private final RedisUtils redisUtils = SpringUtils.getBean(RedisUtils.class);

@Override
public Object intercept(Invocation invocation) throws Throwable {
// 如果未登录,则不进行字典数据权限过滤
if (!SecurityUtils.isAuthenticated()) {
return invocation.proceed();
}
Expand All @@ -65,7 +65,9 @@ public Object intercept(Invocation invocation) throws Throwable {

var roleDictDataMap = SecurityUtils.getDictPermissions();

// 获取所有有字典数据权限的字段
var fields = new LinkedHashMap<Field, DictPermission>();
// 获取表名
var tableName = findTableName(originalSql);

if (mappedStatement.getId().endsWith("_COUNT")) {
Expand All @@ -74,12 +76,23 @@ public Object intercept(Invocation invocation) throws Throwable {
fields.putAll(getFields(mappedStatement.getResultMaps().get(0).getType()));
}

// 如果没有字典数据权限字段,则不进行字典数据权限过滤
if (fields.isEmpty()) {
return invocation.proceed();
}

// 拼接字典数据权限 SQL
var extraSql = new StringBuilder();
fields.forEach((field, dictPermission) -> {
var dictDataList = selectDictDataListByDictType(dictPermission.name());
var alias = findTableAlias(originalSql, tableName);
extraSql.append(format(" and {}.", alias))
.append(field.getName()).append(" in (");
extraSql.append(" and ");
if (Objects.nonNull(alias)) {
extraSql.append(alias)
.append(".");
}
extraSql.append(field.getName())
.append(" in (");
var roleDictDataList = roleDictDataMap.get(dictPermission.name());
extraSql.append(dictDataList.stream()
.filter(dictData -> roleDictDataList.contains(dictData.getDataValue()))
Expand All @@ -88,6 +101,7 @@ public Object intercept(Invocation invocation) throws Throwable {
extraSql.append(") ");
});

// 插入字典数据权限 SQL
var keywordList = List.of("group by", "having", "order by", "limit");
var indexKeyword = keywordList.stream()
.filter(keyword -> StringUtils.containsIgnoreCase(originalSql, keyword))
Expand Down Expand Up @@ -115,6 +129,12 @@ public Object intercept(Invocation invocation) throws Throwable {
return executor.query(mappedStatement, parameter, rowBounds, resultHandler, cacheKey, boundSql);
}

/**
* 根据字典类型查询字典数据集合
*
* @param dictType 字典类型
* @return 字典数据集合
*/
private List<SysDictDataModel> selectDictDataListByDictType(String dictType) {
var dictDataList = redisUtils.get(DICT_TYPE + dictType);
if (dictDataList instanceof List<?> list) {
Expand All @@ -139,20 +159,35 @@ private static <T> LinkedHashMap<Field, DictPermission> getFields(Class<T> clazz
.collect(Collectors.toMap(field -> field, field -> AnnotationUtils.getAnnotation(field, DictPermission.class), (k1, k2) -> k1, LinkedHashMap::new));
}

/**
* 获取表名
*
* @param sql SQL 语句
* @return 表名
*/
@Nullable
private static String findTableName(String sql) {
final var pattern = Pattern.compile("(?i)\\bfrom\\b\\s+(\\w+)");
final var matcher = pattern.matcher(sql);
if (!matcher.find()) {
return StringUtils.EMPTY;
return null;
}
return matcher.group(1);
}

/**
* 获取表别名
*
* @param sql SQL 语句
* @param tableName 表名
* @return 表别名
*/
@Nullable
private static String findTableAlias(String sql, String tableName) {
final var pattern = Pattern.compile("(?i)\\b" + tableName + "\\b\\s+(\\w+)");
final var matcher = pattern.matcher(sql);
if (!matcher.find()) {
return StringUtils.EMPTY;
return null;
}
return matcher.group(1);
}
Expand Down

0 comments on commit 6951ffe

Please sign in to comment.