superior-sql-parser

这个项目支持校验某个用户对某个表是否具有增删改权限吗?

Open · KuanKuanya opened this issue 1 year ago · 1 comment

KuanKuanya · Jul 25, 2024 08:07

本项目只负责解析 SQL;可以结合解析出来的信息,自己实现权限校验。例如我们的实现:

/**
 * Authorizes SQL statements. Each parsed statement is mapped to the tables and functions it
 * touches, and those are checked against table ownership and the user's table/database privilege
 * records.
 *
 * <p>TODO:
 * <ol>
 *   <li>Support authorization checks for stored procedures.</li>
 *   <li>Support authorization checks for Flink CDC.</li>
 * </ol>
 */
@Service
public class AuthorizationService {

    private static final Logger LOG = LoggerFactory.getLogger(AuthorizationService.class);

    @Autowired
    private TableAccessLogService tableAccessLogService;

    @Autowired
    private UserInfoService userInfoService;

    @Autowired
    private SuperiorBeeConfigClient configClient;

    @Autowired
    private FunctionService functionService;

    @Autowired
    private DataSourceService dataSourceService;

    @Autowired
    private TableService tableService;

    @Autowired
    private WorkspaceService workspaceService;

    @Autowired
    private SecTablePrivsService tablePrivsService;

    @Autowired
    private SecDatabasePrivsService databasePrivsService;

    /**
     * Parses {@code sql} and verifies that the user in {@code context} may execute it.
     *
     * <p>The statement must always be of a supported type. Users listed under
     * {@code SuperiorConf.SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS} skip the per-statement privilege
     * checks.
     *
     * @param context         caller identity plus the session's current catalog and database
     * @param sql             SQL text to authorize
     * @param sparkTempTables Spark temporary view names visible to the session (may be
     *                        {@code null}); references to them in the current database are not checked
     * @throws IllegalArgumentException if the context carries no user id
     * @throws SQLParserException       if the statement type is not supported
     */
    @Transactional
    public void checkAuthority(AuthContext context, String sql, String[] sparkTempTables) {
        String userId = context.getUserId();
        if (StringUtils.isBlank(userId)) {
            throw new IllegalArgumentException("userId can not empty");
        }

        Statement statement = SparkSqlHelper.parseStatement(sql);
        boolean supportedSql = SparkSqlHelper.checkSupportedSQL(statement.getStatementType());
        if (!supportedSql) {
            throw new SQLParserException("not support sql: " + sql);
        }

        String[] superTableOwners = configClient.getStringArray(SuperiorConf.SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS);
        if (!ArrayUtils.contains(superTableOwners, userId)) {
            // Check the per-statement execution privileges.
            this.checkAuthority(context, statement, sparkTempTables);
        }
    }

    /**
     * Verifies that the user in {@code context} may execute an already-parsed statement.
     *
     * <p>Rules by statement type ("checked" means {@code checkAccessTableAuth} with the statement's
     * own privilege type, unless READ is stated):
     * <ul>
     *   <li>SHOW: always allowed.</li>
     *   <li>CREATE TABLE / CREATE TABLE LIKE, DROP TABLE, DROP VIEW, TRUNCATE, DELETE, UPDATE:
     *       the target table is checked. DROP honours IF EXISTS.</li>
     *   <li>CREATE VIEW, CREATE TABLE AS SELECT, INSERT: the target table(s) are checked, every
     *       source table needs READ, and every referenced function must be accessible.</li>
     *   <li>SELECT: every source table needs READ, and every referenced function must be
     *       accessible.</li>
     *   <li>MERGE: non-temporary source tables and the target table are checked.</li>
     *   <li>EXPORT: input tables and functions are checked.</li>
     *   <li>DATATUNNEL: input tables and functions are checked. The external datasource source
     *       table needs READ and the sink table needs WRITE.</li>
     *   <li>CACHE: a cached query is checked as its own statement; otherwise the cached table is
     *       checked.</li>
     *   <li>ALTER: the table is checked (honouring IF EXISTS). For ALTER VIEW ... AS query, the
     *       query's sources need READ and its functions must be accessible.</li>
     *   <li>CALL: when the procedure has a {@code table} property, the user must own that table.</li>
     * </ul>
     * Source tables of SELECT, INSERT and CREATE TABLE AS SELECT are recorded in the table access
     * log.
     */
    @Transactional(rollbackFor = Exception.class)
    public void checkAuthority(AuthContext context, Statement statement, String[] sparkTempTables) {
        final StatementType statementType = statement.getStatementType();
        if (SHOW == statementType) {
            return;
        }
        final PrivilegeType privilegeType = statement.getPrivilegeType();

        if (CREATE_TABLE == statementType || CREATE_TABLE_AS_LIKE == statementType) {
            CreateTable createTable = (CreateTable) statement;
            checkAccessTableAuth(context, createTable.getTableId(), privilegeType);
        } else if (CREATE_VIEW == statementType) {
            CreateView createView = (CreateView) statement;
            checkAccessTableAuth(context, createView.getTableId(), privilegeType);
            QueryStmt queryStmt = createView.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
        } else if (CREATE_TABLE_AS_SELECT == statementType) {
            CreateTableAsSelect tableAsSelect = (CreateTableAsSelect) statement;
            checkAccessTableAuth(context, tableAsSelect.getTableId(), privilegeType);
            QueryStmt queryStmt = tableAsSelect.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            addTabAccessLog(context, queryStmt.getInputTables()); // record the query access
        } else if (DROP_TABLE == statementType) {
            DropTable dropTable = (DropTable) statement;
            checkAccessTableAuth(context, dropTable.getTableId(), privilegeType, dropTable.getIfExists());
        } else if (DROP_VIEW == statementType) {
            DropView view = (DropView) statement;
            checkAccessTableAuth(context, view.getTableId(), privilegeType, view.getIfExists());
        } else if (TRUNCATE_TABLE == statementType) {
            TruncateTable truncateTable = (TruncateTable) statement;
            checkAccessTableAuth(context, truncateTable.getTableId(), privilegeType);
        } else if (SELECT == statementType) {
            QueryStmt queryStmt = (QueryStmt) statement;
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            addTabAccessLog(context, queryStmt.getInputTables()); // record the query access
        } else if (INSERT == statementType) { // may be a multi-insert with several outputs
            InsertTable multiInsertStmt = (InsertTable) statement;
            QueryStmt queryStmt = multiInsertStmt.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());

            for (TableId tableId : multiInsertStmt.getOutputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }

            addTabAccessLog(context, queryStmt.getInputTables()); // record the query access
        } else if (DELETE == statementType) {
            DeleteTable deleteTable = (DeleteTable) statement;
            checkAccessTableAuth(context, deleteTable.getTableId(), privilegeType);
        } else if (UPDATE == statementType) {
            UpdateTable updateTable = (UpdateTable) statement;
            checkAccessTableAuth(context, updateTable.getTableId(), privilegeType);
        } else if (MERGE == statementType) {
            MergeTable mergeIntoTable = (MergeTable) statement;
            // NOTE(review): MERGE sources are checked against the statement's privilege type, not
            // READ as INSERT sources are — confirm this is intended.
            mergeIntoTable.getInputTables().forEach(tableId -> {
                if (!isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId)) {
                    checkAccessTableAuth(context, tableId, privilegeType);
                }
            });

            checkAccessTableAuth(context, mergeIntoTable.getTargetTable(), privilegeType);
        } else if (EXPORT_TABLE == statementType) {
            ExportTable tableData = (ExportTable) statement;
            for (TableId tableId : tableData.getInputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }
            checkFunctionAuth(context, tableData.getFunctionNames());
        } else if (DATATUNNEL == statementType && statement instanceof DataTunnelExpr) {
            DataTunnelExpr dataTunnelExpr = (DataTunnelExpr) statement;
            for (TableId tableId : dataTunnelExpr.getInputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }

            checkDatatunnelAuthority(
                    context.getTenantId(), context.getRegionCode(), context.getUserId(), dataTunnelExpr);
            checkFunctionAuth(context, dataTunnelExpr.getFunctionNames());
        } else if (CACHE == statementType) { // spark cache
            CacheTable cacheTable = (CacheTable) statement;
            QueryStmt queryStmt = cacheTable.getQueryStmt();
            if (queryStmt != null) {
                this.checkAuthority(context, queryStmt, sparkTempTables);
            } else {
                checkAccessTableAuth(context, cacheTable.getTableId(), privilegeType);
            }
        } else if (ALTER_TABLE == statementType) {
            checkAlterTableAuthority(context, statement, sparkTempTables, privilegeType);
        } else if (CALL == statementType) {
            checkCallProcedureAuthority(context, (CallProcedure) statement);
        }
    }

    /**
     * Handles CALL statements. When the procedure carries a {@code table} property of the form
     * {@code table} or {@code database.table}, the user must own that table.
     *
     * @throws SuperiorException if the {@code table} property is null or has more than two parts
     */
    private void checkCallProcedureAuthority(AuthContext context, CallProcedure procedure) {
        if (!procedure.getProperties().containsKey("table")) {
            return;
        }

        String tableId = procedure.getProperties().get("table");
        // StringUtils.split returns null for a null input, and an empty array for an empty one.
        String[] items = StringUtils.split(tableId, ".");
        String databaseName = null;
        String tableName;
        if (items != null && items.length == 1) {
            tableName = items[0];
        } else if (items != null && items.length == 2) {
            databaseName = items[0];
            tableName = items[1];
        } else {
            throw new SuperiorException("Unsupported identifier " + tableId);
        }
        checkAdminTableAuthority(context, databaseName, tableName);
    }

    /**
     * Requires the user to own the given table.
     *
     * @param workspaceCode database named explicitly in the identifier, or {@code null} to fall
     *                      back to the session's current database
     * @param tableName     table name; lower-cased before lookup
     * @throws AccessTableException if the table does not exist or the user is not its owner
     */
    private void checkAdminTableAuthority(AuthContext context, String workspaceCode, String tableName) {

        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();

        String catalogName = workspaceService.getCatalogName(tenantId, regionCode, workspaceCode, currentCatalog);
        // (current database, explicit database), matching every other getDatabaseName call in this
        // class, so an explicit "db." prefix takes precedence over the session database.
        String databaseName = CommonUtils.getDatabaseName(currentDatabase, workspaceCode);
        tableName = StringUtils.lowerCase(tableName);

        TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
        if (table == null) {
            String msg = String.format("table not exist: %s.%s.%s", catalogName, databaseName, tableName);
            throw new AccessTableException(msg);
        }
        if (!tablePrivsService.checkOwner(table, userId)) {
            String msg = String.format("%s 不是表owner: %s.%s.%s", userId, catalogName, databaseName, tableName);
            throw new AccessTableException(msg);
        }
    }

    /**
     * Checks the external datasource endpoints of a DATATUNNEL statement. The source table needs
     * READ and the sink table needs WRITE.
     */
    private void checkDatatunnelAuthority(
            Long tenantId, String regionCode, String userId, DataTunnelExpr dataTunnelExpr) {
        checkDatatunnelEndpoint(
                tenantId, regionCode, userId, dataTunnelExpr.getSourceOptions(), PrivilegeType.READ);
        checkDatatunnelEndpoint(
                tenantId, regionCode, userId, dataTunnelExpr.getSinkOptions(), PrivilegeType.WRITE);
    }

    /**
     * Checks one DATATUNNEL endpoint described by its option map. The check is skipped when there
     * is no {@code datasource} option, or when the table or schema name is blank. The schema is
     * read from {@code databaseName}, falling back to {@code schemaName}.
     */
    private void checkDatatunnelEndpoint(
            Long tenantId,
            String regionCode,
            String userId,
            Map<String, Object> options,
            PrivilegeType privilegeType) {
        if (options == null || !options.containsKey("datasource")) {
            return;
        }

        String datasource = (String) options.get("datasource");
        String catalogName = dataSourceService.queryCatalogName(tenantId, regionCode, datasource);
        String schemaName = (String) options.get("databaseName");
        if (schemaName == null) {
            schemaName = (String) options.get("schemaName");
        }
        String tableName = (String) options.get("tableName");

        if (StringUtils.isNotBlank(tableName) && StringUtils.isNotBlank(schemaName)) {
            AuthContext context = new AuthContext(tenantId, regionCode, userId, HIVE, catalogName, schemaName);
            checkAccessTableAuth(context, new TableId(tableName), privilegeType);
        }
    }

    /**
     * Checks an ALTER statement. The altered table is checked (honouring IF EXISTS). For
     * {@code ALTER VIEW ... AS query}, the query's source tables need READ and its functions must
     * be accessible.
     */
    private void checkAlterTableAuthority(
            AuthContext context, Statement statement, String[] sparkTempTables, PrivilegeType privilegeType) {
        AlterTable alterTable = (AlterTable) statement;
        TableId tableId = alterTable.getTableId();

        checkAccessTableAuth(context, tableId, privilegeType, alterTable.getIfExists());

        AlterActionType alterType = alterTable.getFirstAlterType();
        if (AlterActionType.ALTER_VIEW_QUERY == alterType) {
            AlterViewAction view = (AlterViewAction) alterTable.firstAction();
            QueryStmt queryStmt = view.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
        }
    }

    /**
     * Records that the user queried each of {@code inputTables}. Does nothing when the user has no
     * user-info record.
     */
    private void addTabAccessLog(AuthContext context, List<TableId> inputTables) {

        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();

        UserInfoEntity userInfoEntity = userInfoService.queryUser(tenantId, userId);
        if (userInfoEntity == null) {
            return;
        }
        String userName = userInfoEntity.getCnName();
        for (TableId table : inputTables) {
            String catalogName = CommonUtils.getCatalogName(currentCatalog, table.getCatalogName());
            String databaseName = CommonUtils.getDatabaseName(currentDatabase, table.getSchemaName());
            String tableName = table.getTableName().toLowerCase();
            this.tableAccessLogService.update(
                    tenantId, regionCode, catalogName, databaseName, tableName, userName, userId);
        }
    }

    /**
     * Requires READ on every table in {@code inputTables}. Spark temporary views are skipped, and
     * so is the {@code dual} table for ORACLE, DAMENG and OCEANBASE data sources.
     */
    private void checkSelectTableAuth(AuthContext context, List<TableId> inputTables, String[] sparkTempTables) {
        DataSourceType dataSourceType = context.getDataSourceType();
        boolean skipDual = ORACLE == dataSourceType || DAMENG == dataSourceType || OCEANBASE == dataSourceType;
        for (TableId tableId : inputTables) {
            if (skipDual && StringUtils.equalsIgnoreCase("dual", tableId.getTableName())) {
                continue;
            }
            if (!isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId)) {
                checkAccessTableAuth(context, tableId, PrivilegeType.READ);
            }
        }
    }

    /**
     * Verifies that the user may call every function in {@code functionNames}.
     *
     * <p>A function without a schema resolves to the current database. Database and function names
     * are lower-cased and looked up in the current catalog. Functions with no metadata record are
     * not checked. Access scopes:
     * <ul>
     *   <li>{@code AUTH_ONESELF}: only the function's creator.</li>
     *   <li>{@code AUTH_WORKSPACE_USERS}: only when the function's database is the current
     *       database (compared case-insensitively).</li>
     *   <li>{@code AUTH_ASSIGN_USERS}: only users contained in the function's authorized users.</li>
     * </ul>
     *
     * @param functionNames functions referenced by the statement; {@code null} means none
     * @throws AccessFunctionException if any function is not accessible to the user
     */
    private void checkFunctionAuth(AuthContext context, HashSet<FunctionId> functionNames) {
        if (functionNames == null) {
            return;
        }

        LOG.info("context: {} function names: {}", context, StringUtils.join(functionNames, ","));

        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();

        for (FunctionId functionId : functionNames) {
            String schemaName = functionId.getSchemaName();
            String database = StringUtils.lowerCase(schemaName == null ? currentDatabase : schemaName);
            String funcName = StringUtils.lowerCase(functionId.getFunctionName());
            FunctionEntity function =
                    functionService.queryFunction(tenantId, regionCode, currentCatalog, database, funcName);
            if (function == null) {
                continue;
            }

            if (AUTH_ONESELF.equals(function.getAuthType())) {
                if (!userId.equals(function.getCreater())) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅个人可用", funcName);
                }
            } else if (AUTH_WORKSPACE_USERS.equals(function.getAuthType())) {
                // `database` was lower-cased above, so compare case-insensitively. This is also
                // null-safe when the session has no current database.
                if (StringUtils.isNotBlank(database) && !StringUtils.equalsIgnoreCase(currentDatabase, database)) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅项目空间成员可用", funcName);
                }
            } else if (AUTH_ASSIGN_USERS.equals(function.getAuthType())) {
                // NOTE(review): if getAuthUsers() returns a delimited String rather than a
                // collection, contains() is a substring match — confirm.
                if (!function.getAuthUsers().contains(userId)) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:指定用户可用", funcName);
                }
            }
        }
    }

    /**
     * Returns whether {@code tableId} refers to one of {@code sparkTempTables}. Only tables that
     * resolve to the current database are considered, and names are compared case-insensitively.
     */
    private boolean isSparkTempTable(String currentDatabaseName, String[] sparkTempTables, TableId tableId) {
        if (ArrayUtils.isEmpty(sparkTempTables)) {
            return false;
        }

        String schemaName = CommonUtils.getDatabaseName(currentDatabaseName, tableId.getSchemaName());
        if (!StringUtils.equalsIgnoreCase(currentDatabaseName, schemaName)) {
            return false;
        }

        String tableName = tableId.getTableName();
        for (String name : sparkTempTables) {
            if (StringUtils.equalsIgnoreCase(name, tableName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Verifies that the user may access {@code tableId} with {@code privilegeType}. A missing
     * table is an error.
     */
    @Transactional
    public void checkAccessTableAuth(AuthContext authContext, TableId tableId, PrivilegeType privilegeType) {
        this.checkAccessTableAuth(authContext, tableId, privilegeType, false);
    }

    /**
     * Verifies that the user may access one table with the given privilege.
     *
     * <ol>
     *   <li>A Paimon system-table suffix ({@code name$xxx}, with {@code xxx} in
     *       {@code PAIMON_SYS_TABLES}) is stripped, so the check applies to the base table.</li>
     *   <li>Unless the privilege is CREATE, the table must exist, and its owner is always
     *       allowed.</li>
     *   <li>Otherwise the user's table-level privilege rows are consulted. Expired rows are
     *       switched to status 9 (when currently 1 or 15) and never grant access.</li>
     *   <li>If no unexpired table-level row exists, database-level rows are consulted the same
     *       way. Expired rows are switched to status 9 when currently 15.</li>
     * </ol>
     *
     * @param ifExists when {@code true}, a missing table is silently accepted (IF EXISTS statements)
     * @throws AccessTableException if the table does not exist or no privilege grants access
     * @throws SuperiorException    if the database name resolves to blank
     */
    private void checkAccessTableAuth(
            AuthContext context, TableId tableId, PrivilegeType privilegeType, boolean ifExists) {

        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String catalogName = CommonUtils.getCatalogName(context.getCurrentCatalog(), tableId.getCatalogName());
        String databaseName = CommonUtils.getDatabaseName(context.getCurrentDatabase(), tableId.getSchemaName());
        String tableName = tableId.getTableName();

        // Paimon system table, e.g. SELECT * FROM hive_metastore.bigdata.paimon_users_ods$schemas:
        // authorize against the base table.
        String paimonSysTableName = StringUtils.substringAfterLast(tableName, "$");
        if (StringUtils.isNotBlank(paimonSysTableName)
                && ArrayUtils.contains(PAIMON_SYS_TABLES, paimonSysTableName.toLowerCase())) {
            tableName = StringUtils.substringBeforeLast(tableName, "$");
        }

        if (privilegeType != PrivilegeType.CREATE) {
            TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
            if (table == null) {
                if (ifExists) {
                    return; // IF EXISTS statement: a missing table is not an error
                }
                throw new AccessTableException("table not exist: {}.{}.{}", catalogName, databaseName, tableName);
            }
            if (tablePrivsService.checkOwner(table, userId)) {
                return;
            }
        }

        if (StringUtils.isBlank(databaseName)) {
            throw new SuperiorException("databaseName can not blank");
        }

        LocalDate today = LocalDate.now();

        List<SecTablePrivsEntity> tablePrivsList =
                tablePrivsService.queryTablePrivs(tenantId, userId, catalogName, databaseName, tableName);
        boolean hasActiveTablePrivs = false;
        for (SecTablePrivsEntity tablePrivs : tablePrivsList) {
            if (today.isAfter(tablePrivs.getExpireDate())) {
                // Expired grant: persist the expiry, and never let it authorize this request.
                if (tablePrivs.getStatus() == 1 || tablePrivs.getStatus() == 15) {
                    tablePrivs.setStatus(9);
                    tablePrivsService.updateEntity(tablePrivs);
                }
                continue;
            }
            hasActiveTablePrivs = true;
            if (tablePrivsGrants(tablePrivs, privilegeType)) {
                return;
            }
        }

        // Database-level grants apply only when the user has no unexpired table-level grant.
        if (!hasActiveTablePrivs) {
            List<SecDatabasePrivsEntity> databasePrivsList =
                    databasePrivsService.queryDatabasePrivs(tenantId, userId, catalogName, databaseName);
            if (databasePrivsList.isEmpty()) {
                throw new AccessTableException(
                        "{} 没有申请表: {}.{}.{} 使用权限: {}", userId, catalogName, databaseName, tableName, privilegeType);
            }

            for (SecDatabasePrivsEntity databasePrivs : databasePrivsList) {
                if (today.isAfter(databasePrivs.getExpireDate())) {
                    // Expired grant: persist the expiry, and never let it authorize this request.
                    if (databasePrivs.getStatus() == 15) {
                        databasePrivs.setStatus(9);
                        databasePrivsService.updateEntity(databasePrivs);
                    }
                    continue;
                }
                if (databasePrivsGrants(databasePrivs, privilegeType)) {
                    return;
                }
            }
        }

        String msg = String.format(
                "%s 没有申请表: %s.%s.%s 使用权限: %s", userId, catalogName, databaseName, tableName, privilegeType);
        throw new AccessTableException(msg);
    }

    /**
     * Returns whether a table-level grant covers {@code privilegeType}.
     *
     * @throws AccessTableException for privilege types that have no table-level flag (e.g. CREATE)
     */
    private static boolean tablePrivsGrants(SecTablePrivsEntity tablePrivs, PrivilegeType privilegeType) {
        switch (privilegeType) {
            case READ:
                return tablePrivs.isReadPriv();
            case WRITE:
                return tablePrivs.isWritePriv();
            case ALTER:
                return tablePrivs.isAlterPriv();
            case DROP:
                return tablePrivs.isDropPriv();
            default:
                throw new AccessTableException("not support " + privilegeType);
        }
    }

    /**
     * Returns whether a database-level grant covers {@code privilegeType}.
     *
     * @throws AccessTableException for privilege types that have no database-level flag
     */
    private static boolean databasePrivsGrants(SecDatabasePrivsEntity databasePrivs, PrivilegeType privilegeType) {
        switch (privilegeType) {
            case CREATE:
                return databasePrivs.isCreatePriv();
            case READ:
                return databasePrivs.isReadPriv();
            case WRITE:
                return databasePrivs.isWritePriv();
            case ALTER:
                return databasePrivs.isAlterPriv();
            case DROP:
                return databasePrivs.isDropPriv();
            default:
                throw new AccessTableException("not support " + privilegeType);
        }
    }
}

melin · Jul 25, 2024 10:07