diff --git a/tang-framework/src/main/java/com/tang/framework/interceptor/DictPermissionInterceptor.java b/tang-framework/src/main/java/com/tang/framework/interceptor/DictPermissionInterceptor.java index 4ee94f03..c7743676 100644 --- a/tang-framework/src/main/java/com/tang/framework/interceptor/DictPermissionInterceptor.java +++ b/tang-framework/src/main/java/com/tang/framework/interceptor/DictPermissionInterceptor.java @@ -22,6 +22,7 @@ import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.lang.Nullable; import com.tang.commons.annotation.DictPermission; import com.tang.commons.model.SysDictDataModel; @@ -43,12 +44,11 @@ }) public class DictPermissionInterceptor implements Interceptor { - private static final String LIMIT = "LIMIT"; - private final RedisUtils redisUtils = SpringUtils.getBean(RedisUtils.class); @Override public Object intercept(Invocation invocation) throws Throwable { + // 如果未登录,则不进行字典数据权限过滤 if (!SecurityUtils.isAuthenticated()) { return invocation.proceed(); } @@ -65,7 +65,9 @@ public Object intercept(Invocation invocation) throws Throwable { var roleDictDataMap = SecurityUtils.getDictPermissions(); + // 获取所有有字典数据权限的字段 var fields = new LinkedHashMap(); + // 获取表名 var tableName = findTableName(originalSql); if (mappedStatement.getId().endsWith("_COUNT")) { @@ -74,12 +76,23 @@ public Object intercept(Invocation invocation) throws Throwable { fields.putAll(getFields(mappedStatement.getResultMaps().get(0).getType())); } + // 如果没有字典数据权限字段,则不进行字典数据权限过滤 + if (fields.isEmpty()) { + return invocation.proceed(); + } + + // 拼接字典数据权限 SQL var extraSql = new StringBuilder(); fields.forEach((field, dictPermission) -> { var dictDataList = selectDictDataListByDictType(dictPermission.name()); var alias = findTableAlias(originalSql, tableName); - extraSql.append(format(" and {}.", alias)) - .append(field.getName()).append(" in ("); + extraSql.append(" and "); + if (Objects.nonNull(alias)) { + extraSql.append(alias) + .append("."); + } + extraSql.append(field.getName()) + .append(" in ("); var roleDictDataList = roleDictDataMap.get(dictPermission.name()); extraSql.append(dictDataList.stream() .filter(dictData -> roleDictDataList.contains(dictData.getDataValue())) @@ -88,6 +101,7 @@ public Object intercept(Invocation invocation) throws Throwable { extraSql.append(") "); }); + // 插入字典数据权限 SQL var keywordList = List.of("group by", "having", "order by", "limit"); var indexKeyword = keywordList.stream() .filter(keyword -> StringUtils.containsIgnoreCase(originalSql, keyword)) @@ -115,6 +129,12 @@ public Object intercept(Invocation invocation) throws Throwable { return executor.query(mappedStatement, parameter, rowBounds, resultHandler, cacheKey, boundSql); } + /** + * 根据字典类型查询字典数据集合 + * + * @param dictType 字典类型 + * @return 字典数据集合 + */ private List selectDictDataListByDictType(String dictType) { var dictDataList = redisUtils.get(DICT_TYPE + dictType); if (dictDataList instanceof List list) { @@ -139,20 +159,35 @@ private static LinkedHashMap getFields(Class clazz .collect(Collectors.toMap(field -> field, field -> AnnotationUtils.getAnnotation(field, DictPermission.class), (k1, k2) -> k1, LinkedHashMap::new)); } + /** + * 获取表名 + * + * @param sql SQL 语句 + * @return 表名 + */ + @Nullable private static String findTableName(String sql) { final var pattern = Pattern.compile("(?i)\\bfrom\\b\\s+(\\w+)"); final var matcher = pattern.matcher(sql); if (!matcher.find()) { - return StringUtils.EMPTY; + return null; } return matcher.group(1); } + /** + * 获取表别名 + * + * @param sql SQL 语句 + * @param tableName 表名 + * @return 表别名 + */ + @Nullable private static String findTableAlias(String sql, String tableName) { final var pattern = Pattern.compile("(?i)\\b" + tableName + "\\b\\s+(\\w+)"); final var matcher = pattern.matcher(sql); if (!matcher.find()) { - return StringUtils.EMPTY; + return null; } return matcher.group(1); }