superior-sql-parser
这个项目是否支持校验某个用户对某张表是否具有增删改权限？
本项目只负责解析 SQL，不做权限校验；可以结合解析出来的信息（输入表、输出表、函数等）自己实现权限校验。例如我们的实现：
/**
 * Authorizes SQL statements parsed by superior-sql-parser.
 *
 * <p>A statement is parsed, then every table it reads or writes and every function it calls is
 * checked against table ownership, table-level grants, database-level grants and function access
 * scopes.
 *
 * <p>TODO:
 * <ol>
 *   <li>Support permission checks for stored procedures.</li>
 *   <li>Support permission checks for Flink CDC.</li>
 * </ol>
 */
@Service
public class AuthorizationService {

    private static final Logger LOG = LoggerFactory.getLogger(AuthorizationService.class);

    @Autowired
    private TableAccessLogService tableAccessLogService;

    @Autowired
    private UserInfoService userInfoService;

    @Autowired
    private SuperiorBeeConfigClient configClient;

    @Autowired
    private FunctionService functionService;

    @Autowired
    private DataSourceService dataSourceService;

    @Autowired
    private TableService tableService;

    @Autowired
    private WorkspaceService workspaceService;

    @Autowired
    private SecTablePrivsService tablePrivsService;

    @Autowired
    private SecDatabasePrivsService databasePrivsService;

    /**
     * Parses {@code sql} and checks that the user in {@code context} may execute it.
     *
     * <p>Users listed in {@code SuperiorConf.SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS} skip the check.
     *
     * @param context         caller identity plus current catalog/database
     * @param sql             SQL text to authorize
     * @param sparkTempTables names of Spark temporary views that need no table permission; may be null
     * @throws IllegalArgumentException if the context carries no user id
     * @throws SQLParserException       if the parsed statement type is not supported
     */
    @Transactional
    public void checkAuthority(AuthContext context, String sql, String[] sparkTempTables) {
        String userId = context.getUserId();
        if (StringUtils.isBlank(userId)) {
            throw new IllegalArgumentException("userId can not empty");
        }

        Statement statement = SparkSqlHelper.parseStatement(sql);
        if (!SparkSqlHelper.checkSupportedSQL(statement.getStatementType())) {
            throw new SQLParserException("not support sql: " + sql);
        }

        String[] superTableOwners = configClient.getStringArray(SuperiorConf.SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS);
        if (!ArrayUtils.contains(superTableOwners, userId)) {
            this.checkAuthority(context, statement, sparkTempTables);
        }
    }

    /**
     * Checks that the user in {@code context} may execute an already-parsed statement.
     *
     * <p>"privilege" below means the statement's own {@code getPrivilegeType()}. For any privilege
     * other than CREATE, the owner of an existing table always passes.
     * <ul>
     *   <li>SHOW: no check.</li>
     *   <li>CREATE TABLE / CREATE TABLE LIKE / TRUNCATE / DELETE / UPDATE: privilege on the target.</li>
     *   <li>CREATE VIEW / CREATE TABLE AS SELECT: privilege on the target, READ on source tables,
     *       function access; CTAS also records the table accesses.</li>
     *   <li>DROP TABLE / DROP VIEW / ALTER TABLE: privilege on the target; a missing table is tolerated
     *       with IF EXISTS. ALTER VIEW ... AS additionally needs READ on the query's tables.</li>
     *   <li>SELECT: READ on input tables, function access, access log.</li>
     *   <li>INSERT (including multi-insert): READ on input tables, function access, privilege on
     *       every output table, access log.</li>
     *   <li>MERGE: READ on source tables, privilege on the target table.</li>
     *   <li>EXPORT TABLE / DATATUNNEL: privilege on input tables and function access; DATATUNNEL also
     *       checks its datasource source (READ) and sink (WRITE) tables.</li>
     *   <li>CACHE: the cached query is checked as its own statement, otherwise privilege on the table.</li>
     *   <li>CALL: the user must own the table named by the procedure's {@code table} property.</li>
     * </ul>
     * Spark temporary views (and DUAL on Oracle-like sources) are never checked as inputs.
     *
     * @param context         caller identity plus current catalog/database
     * @param statement       parsed statement to authorize
     * @param sparkTempTables names of Spark temporary views; may be null
     */
    @Transactional(rollbackFor = Exception.class)
    public void checkAuthority(AuthContext context, Statement statement, String[] sparkTempTables) {
        final StatementType statementType = statement.getStatementType();
        if (SHOW == statementType) {
            return;
        }

        final PrivilegeType privilegeType = statement.getPrivilegeType();
        if (CREATE_TABLE == statementType || CREATE_TABLE_AS_LIKE == statementType) {
            CreateTable createTable = (CreateTable) statement;
            checkAccessTableAuth(context, createTable.getTableId(), privilegeType);
        } else if (CREATE_VIEW == statementType) {
            CreateView createView = (CreateView) statement;
            checkAccessTableAuth(context, createView.getTableId(), privilegeType);
            QueryStmt queryStmt = createView.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
        } else if (CREATE_TABLE_AS_SELECT == statementType) {
            CreateTableAsSelect tableAsSelect = (CreateTableAsSelect) statement;
            checkAccessTableAuth(context, tableAsSelect.getTableId(), privilegeType);
            QueryStmt queryStmt = tableAsSelect.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            addTabAccessLog(context, queryStmt.getInputTables());
        } else if (DROP_TABLE == statementType) {
            DropTable dropTable = (DropTable) statement;
            checkAccessTableAuth(context, dropTable.getTableId(), privilegeType, dropTable.getIfExists());
        } else if (DROP_VIEW == statementType) {
            DropView dropView = (DropView) statement;
            checkAccessTableAuth(context, dropView.getTableId(), privilegeType, dropView.getIfExists());
        } else if (TRUNCATE_TABLE == statementType) {
            TruncateTable truncateTable = (TruncateTable) statement;
            checkAccessTableAuth(context, truncateTable.getTableId(), privilegeType);
        } else if (SELECT == statementType) {
            QueryStmt queryStmt = (QueryStmt) statement;
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            addTabAccessLog(context, queryStmt.getInputTables());
        } else if (INSERT == statementType) { // covers multi-output inserts
            InsertTable multiInsertStmt = (InsertTable) statement;
            QueryStmt queryStmt = multiInsertStmt.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            for (TableId tableId : multiInsertStmt.getOutputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }
            addTabAccessLog(context, queryStmt.getInputTables());
        } else if (DELETE == statementType) {
            DeleteTable deleteTable = (DeleteTable) statement;
            checkAccessTableAuth(context, deleteTable.getTableId(), privilegeType);
        } else if (UPDATE == statementType) {
            UpdateTable updateTable = (UpdateTable) statement;
            checkAccessTableAuth(context, updateTable.getTableId(), privilegeType);
        } else if (MERGE == statementType) {
            MergeTable mergeIntoTable = (MergeTable) statement;
            // The MERGE source tables are only read, so they need READ (Spark temp views are skipped);
            // the statement's own privilege is required on the target table only.
            checkSelectTableAuth(context, mergeIntoTable.getInputTables(), sparkTempTables);
            checkAccessTableAuth(context, mergeIntoTable.getTargetTable(), privilegeType);
        } else if (EXPORT_TABLE == statementType) {
            ExportTable tableData = (ExportTable) statement;
            for (TableId tableId : tableData.getInputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }
            checkFunctionAuth(context, tableData.getFunctionNames());
        } else if (DATATUNNEL == statementType && statement instanceof DataTunnelExpr) {
            DataTunnelExpr dataTunnelExpr = (DataTunnelExpr) statement;
            for (TableId tableId : dataTunnelExpr.getInputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }
            long tenantId = context.getTenantId();
            String regionCode = context.getRegionCode();
            String userId = context.getUserId();
            checkDatatunnelAuthority(tenantId, regionCode, userId, dataTunnelExpr);
            checkFunctionAuth(context, dataTunnelExpr.getFunctionNames());
        } else if (CACHE == statementType) { // spark cache
            CacheTable cacheTable = (CacheTable) statement;
            QueryStmt queryStmt = cacheTable.getQueryStmt();
            if (queryStmt != null) {
                this.checkAuthority(context, queryStmt, sparkTempTables);
            } else {
                checkAccessTableAuth(context, cacheTable.getTableId(), privilegeType);
            }
        } else if (ALTER_TABLE == statementType) {
            checkAlterTableAuthority(context, statement, sparkTempTables, privilegeType);
        } else if (CALL == statementType) {
            checkCallProcedureAuthority(context, (CallProcedure) statement);
        }
    }

    /**
     * For a CALL whose properties contain a {@code table} entry ({@code table} or
     * {@code database.table}), requires the user to own that table. Other CALLs are not checked.
     *
     * @throws SuperiorException if the {@code table} value is not a one- or two-part identifier
     */
    private void checkCallProcedureAuthority(AuthContext context, CallProcedure procedure) {
        if (!procedure.getProperties().containsKey("table")) {
            return;
        }
        String tableId = procedure.getProperties().get("table");
        String[] items = StringUtils.split(tableId, ".");
        // StringUtils.split returns null for a null input; treat that like any malformed identifier.
        if (items == null || items.length == 0 || items.length > 2) {
            throw new SuperiorException("Unsupported identifier " + tableId);
        }
        String databaseName = items.length == 2 ? items[0] : null;
        String tableName = items[items.length - 1];
        checkAdminTableAuthority(context, databaseName, tableName);
    }

    /**
     * Requires the user to be the owner of {@code workspaceCode.tableName}.
     *
     * @param workspaceCode explicit database of the table, or null to use the current database
     * @param tableName     table name; lower-cased before the lookup
     * @throws AccessTableException if the table does not exist or the user is not its owner
     */
    private void checkAdminTableAuthority(AuthContext context, String workspaceCode, String tableName) {
        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();
        String catalogName = workspaceService.getCatalogName(tenantId, regionCode, workspaceCode, currentCatalog);
        // (current, explicit) argument order, as in every other getDatabaseName call in this class.
        String databaseName = CommonUtils.getDatabaseName(currentDatabase, workspaceCode);
        tableName = StringUtils.lowerCase(tableName);
        TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
        if (table == null) {
            String msg = String.format("table not exist: %s.%s.%s", catalogName, databaseName, tableName);
            throw new AccessTableException(msg);
        }
        if (!tablePrivsService.checkOwner(table, userId)) {
            String msg = String.format("%s 不是表owner: %s.%s.%s", userId, catalogName, databaseName, tableName);
            throw new AccessTableException(msg);
        }
    }

    /**
     * Checks the datasource tables of a DataTunnel statement: READ on the source table and WRITE on
     * the sink table, when the respective option map names a {@code datasource}.
     */
    private void checkDatatunnelAuthority(
            Long tenantId, String regionCode, String userId, DataTunnelExpr dataTunnelExpr) {
        checkDatasourceTableAuth(
                tenantId, regionCode, userId, dataTunnelExpr.getSourceOptions(), PrivilegeType.READ);
        checkDatasourceTableAuth(
                tenantId, regionCode, userId, dataTunnelExpr.getSinkOptions(), PrivilegeType.WRITE);
    }

    /**
     * Checks {@code privilegeType} on the table described by one DataTunnel option map.
     *
     * <p>Does nothing unless the map has a {@code datasource} key. The schema is read from
     * {@code databaseName}, falling back to {@code schemaName}. The table is only checked when both
     * schema and {@code tableName} are non-blank.
     */
    private void checkDatasourceTableAuth(
            Long tenantId, String regionCode, String userId, Map<String, Object> options,
            PrivilegeType privilegeType) {
        if (!options.containsKey("datasource")) {
            return;
        }
        String datasource = (String) options.get("datasource");
        String catalogName = dataSourceService.queryCatalogName(tenantId, regionCode, datasource);
        String schemaName = (String) options.get("databaseName");
        if (schemaName == null) {
            schemaName = (String) options.get("schemaName");
        }
        String tableName = (String) options.get("tableName");
        if (StringUtils.isNotBlank(tableName) && StringUtils.isNotBlank(schemaName)) {
            AuthContext context = new AuthContext(tenantId, regionCode, userId, HIVE, catalogName, schemaName);
            checkAccessTableAuth(context, new TableId(tableName), privilegeType);
        }
    }

    /**
     * Checks an ALTER TABLE statement: the statement's privilege on the table (tolerating a missing
     * table with IF EXISTS) and, for ALTER VIEW ... AS, READ plus function access for the new query.
     * Only the first alter action is inspected for the view query.
     */
    private void checkAlterTableAuthority(
            AuthContext context, Statement statement, String[] sparkTempTables, PrivilegeType privilegeType) {
        AlterTable alterTable = (AlterTable) statement;
        TableId tableId = alterTable.getTableId();
        checkAccessTableAuth(context, tableId, privilegeType, alterTable.getIfExists());
        AlterActionType alterType = alterTable.getFirstAlterType();
        if (AlterActionType.ALTER_VIEW_QUERY == alterType) {
            AlterViewAction view = (AlterViewAction) alterTable.firstAction();
            QueryStmt queryStmt = view.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
        }
    }

    /**
     * Records that the user queried each of {@code inputTables}. Skipped when the user is unknown
     * to {@code userInfoService}. Table names are lower-cased before being recorded.
     */
    private void addTabAccessLog(AuthContext context, List<TableId> inputTables) {
        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();
        UserInfoEntity userInfoEntity = userInfoService.queryUser(tenantId, userId);
        if (userInfoEntity == null) {
            return;
        }
        String userName = userInfoEntity.getCnName();
        for (TableId table : inputTables) {
            String catalogName = CommonUtils.getCatalogName(currentCatalog, table.getCatalogName());
            String databaseName = CommonUtils.getDatabaseName(currentDatabase, table.getSchemaName());
            String tableName = table.getTableName().toLowerCase();
            this.tableAccessLogService.update(
                    tenantId, regionCode, catalogName, databaseName, tableName, userName, userId);
        }
    }

    /**
     * Requires READ on every input table except Spark temporary views and, for Oracle, Dameng and
     * OceanBase sources, the pseudo-table {@code dual}.
     */
    private void checkSelectTableAuth(AuthContext context, List<TableId> inputTables, String[] sparkTempTables) {
        DataSourceType dataSourceType = context.getDataSourceType();
        boolean hasDualTable = ORACLE == dataSourceType || DAMENG == dataSourceType || OCEANBASE == dataSourceType;
        for (TableId tableId : inputTables) {
            if (hasDualTable && StringUtils.equalsIgnoreCase("dual", tableId.getTableName())) {
                continue;
            }
            if (!isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId)) {
                checkAccessTableAuth(context, tableId, PrivilegeType.READ);
            }
        }
    }

    /**
     * Checks the access scope of every registered function the statement calls. Functions not found
     * by {@code functionService} are not checked.
     *
     * <ul>
     *   <li>{@code AUTH_ONESELF}: only the creator may call it.</li>
     *   <li>{@code AUTH_WORKSPACE_USERS}: only callable from its own database (compared case-insensitively
     *       with the current database).</li>
     *   <li>{@code AUTH_ASSIGN_USERS}: only users in the function's auth-user list.</li>
     * </ul>
     *
     * @throws AccessFunctionException if a function's scope excludes the user
     */
    private void checkFunctionAuth(AuthContext context, HashSet<FunctionId> functionNames) {
        if (functionNames == null) {
            return;
        }
        LOG.info("context: {} function names: {}", context, StringUtils.join(functionNames, ","));
        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();
        for (FunctionId functionId : functionNames) {
            String database = functionId.getSchemaName() == null ? currentDatabase : functionId.getSchemaName();
            database = StringUtils.lowerCase(database);
            String funcName = StringUtils.lowerCase(functionId.getFunctionName());
            FunctionEntity function =
                    functionService.queryFunction(tenantId, regionCode, currentCatalog, database, funcName);
            if (function == null) {
                continue;
            }
            if (AUTH_ONESELF.equals(function.getAuthType())) {
                if (!StringUtils.equals(userId, function.getCreater())) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅个人可用", funcName);
                }
            } else if (AUTH_WORKSPACE_USERS.equals(function.getAuthType())) {
                // `database` was lower-cased above, so compare ignoring case; also null-safe.
                if (StringUtils.isNotBlank(database) && !StringUtils.equalsIgnoreCase(currentDatabase, database)) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅项目空间成员可用", funcName);
                }
            } else if (AUTH_ASSIGN_USERS.equals(function.getAuthType())) {
                // NOTE(review): if getAuthUsers() is a delimited String rather than a collection,
                // contains() is a substring match — confirm its type.
                if (!function.getAuthUsers().contains(userId)) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:指定用户可用", funcName);
                }
            }
        }
    }

    /**
     * Returns true when {@code tableId} resolves to the current database and its name matches
     * (case-insensitively) one of {@code sparkTempTables}. A null or empty array means no temp tables.
     */
    private boolean isSparkTempTable(String currentDatabaseName, String[] sparkTempTables, TableId tableId) {
        if (sparkTempTables == null || sparkTempTables.length == 0) {
            return false;
        }
        String schemaName = CommonUtils.getDatabaseName(currentDatabaseName, tableId.getSchemaName());
        if (!StringUtils.equalsIgnoreCase(currentDatabaseName, schemaName)) {
            return false;
        }
        String tableName = StringUtils.lowerCase(tableId.getTableName());
        for (String name : sparkTempTables) {
            if (StringUtils.equalsIgnoreCase(name, tableName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Requires {@code privilegeType} on {@code tableId}; a missing table is an error.
     *
     * @throws AccessTableException if the table is missing or the user lacks the privilege
     */
    @Transactional
    public void checkAccessTableAuth(AuthContext authContext, TableId tableId, PrivilegeType privilegeType) {
        this.checkAccessTableAuth(authContext, tableId, privilegeType, false);
    }

    /**
     * Requires {@code privilegeType} on {@code tableId}.
     *
     * <p>For anything but CREATE the table must exist (unless {@code ifExists}) and its owner always
     * passes. Otherwise a non-expired table-level grant, or failing that a non-expired database-level
     * grant, must cover the privilege. Expired grants found on the way are marked with status 9.
     *
     * @param ifExists when true, a missing table is silently accepted (IF EXISTS statements)
     * @throws AccessTableException if the table is missing or no valid grant covers the privilege
     * @throws SuperiorException    if the database name cannot be resolved
     */
    private void checkAccessTableAuth(
            AuthContext context, TableId tableId, PrivilegeType privilegeType, boolean ifExists) {
        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String catalogName = CommonUtils.getCatalogName(context.getCurrentCatalog(), tableId.getCatalogName());
        String databaseName = CommonUtils.getDatabaseName(context.getCurrentDatabase(), tableId.getSchemaName());
        String tableName = stripPaimonSysTableSuffix(tableId.getTableName());

        if (privilegeType != PrivilegeType.CREATE) {
            TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
            if (table == null) {
                if (ifExists) {
                    return; // e.g. ALTER/DROP ... IF EXISTS on a missing table: not an error
                }
                throw new AccessTableException("table not exist: {}.{}.{}", catalogName, databaseName, tableName);
            }
            if (tablePrivsService.checkOwner(table, userId)) {
                return;
            }
        }
        if (StringUtils.isBlank(databaseName)) {
            throw new SuperiorException("databaseName can not blank");
        }

        // NOTE(review): the status updates below run inside the caller's @Transactional; if this
        // check ends by throwing, they are presumably rolled back unless updateEntity uses its own
        // transaction — confirm. A null expire date throws an NPE here (fails closed).
        LocalDate currentDate = LocalDate.now();
        List<SecTablePrivsEntity> tablePrivsList =
                tablePrivsService.queryTablePrivs(tenantId, userId, catalogName, databaseName, tableName);
        for (SecTablePrivsEntity tablePrivs : tablePrivsList) {
            if (currentDate.isAfter(tablePrivs.getExpireDate())) {
                if (tablePrivs.getStatus() == 1 || tablePrivs.getStatus() == 15) {
                    tablePrivs.setStatus(9);
                    tablePrivsService.updateEntity(tablePrivs);
                }
                continue; // an expired grant must not authorize anything
            }
            if (tableGrants(tablePrivs, privilegeType)) {
                return;
            }
        }

        // No valid table-level grant: fall back to database-level grants.
        List<SecDatabasePrivsEntity> databasePrivsList =
                databasePrivsService.queryDatabasePrivs(tenantId, userId, catalogName, databaseName);
        for (SecDatabasePrivsEntity databasePrivs : databasePrivsList) {
            if (currentDate.isAfter(databasePrivs.getExpireDate())) {
                if (databasePrivs.getStatus() == 15) {
                    databasePrivs.setStatus(9);
                    databasePrivsService.updateEntity(databasePrivs);
                }
                continue; // an expired grant must not authorize anything
            }
            if (databaseGrants(databasePrivs, privilegeType)) {
                return;
            }
        }

        String msg = String.format(
                "%s 没有申请表: %s.%s.%s 使用权限: %s", userId, catalogName, databaseName, tableName, privilegeType);
        throw new AccessTableException(msg);
    }

    /**
     * Maps a Paimon system table ({@code table$sysTable}, e.g. {@code paimon_users_ods$schemas}) to
     * its base table name; other names are returned unchanged.
     */
    private static String stripPaimonSysTableSuffix(String tableName) {
        String paimonSysTableName = StringUtils.substringAfterLast(tableName, "$");
        if (StringUtils.isNotBlank(paimonSysTableName)
                && ArrayUtils.contains(PAIMON_SYS_TABLES, paimonSysTableName.toLowerCase())) {
            return StringUtils.substringBeforeLast(tableName, "$");
        }
        return tableName;
    }

    /**
     * Returns whether a table-level grant covers {@code privilegeType}. No CREATE flag is consulted
     * at table level, so CREATE is left to the database-level check.
     *
     * @throws AccessTableException for privilege types not handled here
     */
    private static boolean tableGrants(SecTablePrivsEntity tablePrivs, PrivilegeType privilegeType) {
        switch (privilegeType) {
            case CREATE:
                return false;
            case READ:
                return tablePrivs.isReadPriv();
            case WRITE:
                return tablePrivs.isWritePriv();
            case ALTER:
                return tablePrivs.isAlterPriv();
            case DROP:
                return tablePrivs.isDropPriv();
            default:
                throw new AccessTableException("not support " + privilegeType);
        }
    }

    /**
     * Returns whether a database-level grant covers {@code privilegeType}.
     *
     * @throws AccessTableException for privilege types not handled here
     */
    private static boolean databaseGrants(SecDatabasePrivsEntity databasePrivs, PrivilegeType privilegeType) {
        switch (privilegeType) {
            case CREATE:
                return databasePrivs.isCreatePriv();
            case READ:
                return databasePrivs.isReadPriv();
            case WRITE:
                return databasePrivs.isWritePriv();
            case ALTER:
                return databasePrivs.isAlterPriv();
            case DROP:
                return databasePrivs.isDropPriv();
            default:
                throw new AccessTableException("not support " + privilegeType);
        }
    }
}