support okl dialect
Definition of the ExtractKernelLaunchTensor pass: after the wrap-kernel-launch pass, the IR passes data purely as a flow of tensors. This pass introduces !okl.launcher_ctx into the data flow, making it the actual manager of the data flow at the okl abstraction level. It also introduces the !okl.get_tensor_from_ctx op to produce tensors for different purposes; each purpose is tagged with a tensor_type enum attribute so that more of the abstract information is preserved.
Definition of WrapOpsToKernelLaunchPass: packs the consecutive computational ops inside an oneflow::job into a func, embeds that func's assembly into a single oneflow.kernel_launch op, and properly rewires the tensor flow of the consecutive ops through the func's arguments and return values. For example,
job(0){ 1 = oneflow.relu(0) 2 = oneflow.relu(1) return 2 }
is converted into
job(0){ x = oneflow.kernel_launch(0) @{"1=oneflow.relu(0) 2=oneflow.relu(1) return 2"}(tensor -> tensor) return x }
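As a rough illustration of the grouping logic (a toy Python model, not the actual C++ pass; the op list representation and the `COMPUTE_OPS` whitelist are invented for illustration), each maximal run of consecutive computational ops is collected and replaced by a single wrapper op:

```python
# Toy sketch of WrapOpsToKernelLaunchPass: scan a job's op list and
# replace each maximal run of computational ops with one kernel_launch
# op carrying the wrapped ops as its "assembly".
COMPUTE_OPS = {"oneflow.relu", "oneflow.tanh"}  # assumption: a whitelist marks "computational" ops

def wrap_ops(job_ops):
    wrapped, run = [], []

    def flush():
        # Close the current run of computational ops, if any.
        if run:
            wrapped.append({"op": "oneflow.kernel_launch",
                            "mlir_assembly": list(run)})
            run.clear()

    for op in job_ops:
        if op["op"] in COMPUTE_OPS:
            run.append(op)       # extend the current run
        else:
            flush()              # non-compute op ends the run
            wrapped.append(op)
    flush()
    return wrapped

job = [{"op": "oneflow.input"},
       {"op": "oneflow.relu"},
       {"op": "oneflow.relu"},
       {"op": "oneflow.output"}]
print([o["op"] for o in wrap_ops(job)])
```

The two relu ops collapse into one kernel_launch entry while input/output stay in place, mirroring the job transformation above.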
At the end of the second RoundTrip pass, the following IR is generated:
module {
oneflow.job @GraphToRun_0(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_input.0.0_2", output_lbns = ["_GraphToRun_0_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [1 : si64]} : (tensor<1xf32>) -> tensor<1xf32>
%0 = "oneflow.relu"(%output) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%1 = "oneflow.relu"(%output) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%output_0 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0_2", output_lbns = ["_GraphToRun_0_output.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [1 : si64]} : (tensor<1xf32>) -> tensor<1xf32>
oneflow.return %output_0 : tensor<1xf32>
}
}
-wrap-ops-to-kernel-launch
- This pass packs consecutive computational oneflow ops into a single function, providing the basic support for abstracting the computation of oneflow ops at the MLIR level.
- It mainly handles wrapping the computational ops as a group and merging/eliminating formal arguments and outputs.
module {
func.func @wrap0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) attributes {llvm.emit_c_interface} {
%0 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%1 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
return %0, %1 : tensor<1xf32>, tensor<1xf32>
}
oneflow.job @GraphToRun_0(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_input.0.0_2", output_lbns = ["_GraphToRun_0_input.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [1 : si64]} : (tensor<1xf32>) -> tensor<1xf32>
%0:2 = oneflow.kernel_launch @wrap0(%output, %output) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], mlir_assembly = "\22func.func\22() ({\0A^bb0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>):\0A %0 = \22oneflow.relu\22(%arg1) {device_name = [\22@0:0\22], device_tag = \22cpu\22, hierarchy = [1], op_name = \22relu-0\22, scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>\0A %1 = \22oneflow.relu\22(%arg1) {device_name = [\22@0:0\22], device_tag = \22cpu\22, hierarchy = [1], op_name = \22relu-0\22, scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>\0A \22func.return\22(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> ()\0A}) {function_type = (tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>), llvm.emit_c_interface, sym_name = \22wrap0\22} : () -> ()", op_name = "wrap0", scope_symbol_id = 12 : i64} : (tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>)
%output_0 = "oneflow.output"(%0#1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_0_output.0.0_2", output_lbns = ["_GraphToRun_0_output.0.0_2/out"], scope_symbol_id = 12 : i64, shape = [1 : si64]} : (tensor<1xf32>) -> tensor<1xf32>
oneflow.return %output_0 : tensor<1xf32>
}
}
In the compute function of oneflow.kernel_launch, the mlir_assembly attribute is parsed to recover:
module {
func.func @wrap0(%arg0: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) attributes {compiled = "true", llvm.emit_c_interface} {
%0 = "oneflow.relu"(%arg0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%1 = "oneflow.relu"(%arg0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
return %0, %1 : tensor<1xf32>, tensor<1xf32>
}
}
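The mlir_assembly attribute above stores the wrapped function with MLIR string-attribute escaping (`\22` is a hex escape for `"`, `\0A` for a newline, `\\` for a backslash). A minimal decoder sketch (a simplification for illustration, not the actual parser, which hands the string to the MLIR parser directly):

```python
import re

# Decode MLIR string-attribute escapes: "\XX" hex pairs and "\\" for a
# literal backslash, as seen in the mlir_assembly attribute above.
def decode_mlir_string(s):
    def sub(m):
        body = m.group(1)
        return "\\" if body == "\\" else chr(int(body, 16))
    return re.sub(r"\\(\\|[0-9A-Fa-f]{2})", sub, s)

print(decode_mlir_string(r'\22oneflow.relu\22(%arg1)\0A'))
```

Running the decoder over the full attribute value reproduces the readable func.func body shown in the parsed module.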
-extract-kernel-launch-tensor
- This pass introduces the !okl.launcher_ctx abstraction, which manages the resources the whole function needs at run time.
- It converts between oneflow-internal resources and the MLIR tensor concept.
- Its outcome may be refined by later passes into more concrete abstractions over !okl.launcher_ctx.
module {
func.func @wrap0(%arg0: !okl.launcher_ctx) -> (tensor<1xf32>, tensor<1xf32>) {
%0 = "okl.get_tensor_from_arg"(%arg0) {tensor_type = 0 : i32} : (!okl.launcher_ctx) -> tensor<1xf32>
%1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%2 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%3 = "okl.get_tensor_as_ret"(%arg0, %1) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
%4 = "okl.get_tensor_as_ret"(%arg0, %2) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
return %3, %4 : tensor<1xf32>, tensor<1xf32>
}
}
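The get_tensor_from_arg / get_tensor_as_ret ops above route tensors through the launcher context instead of through SSA arguments and results. A toy model of that routing (the class and its methods are invented stand-ins; tensor_type 0 marks an argument slot and 2 a return slot, matching the attributes in the dump):

```python
# Toy model of !okl.launcher_ctx tensor routing.
class LauncherCtx:
    def __init__(self, args):
        self.args = list(args)   # tensors handed in by the caller
        self.rets = []           # tensors the wrapped func produces

    def get_tensor_from_arg(self, index):
        # tensor_type = 0: read an input slot
        return self.args[index]

    def get_tensor_as_ret(self, tensor):
        # tensor_type = 2: register a value as an output slot
        self.rets.append(tensor)
        return tensor

ctx = LauncherCtx([[1.0]])
x = ctx.get_tensor_from_arg(0)
y1 = [max(v, 0.0) for v in x]    # stands in for the first oneflow.relu
y2 = [max(v, 0.0) for v in x]    # stands in for the second oneflow.relu
ctx.get_tensor_as_ret(y1)
ctx.get_tensor_as_ret(y2)
print(ctx.rets)
```

Once both results live in the context, the function's own return values carry no extra information, which is what makes the next pass (trim-return-as-void) possible.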
-trim-return-as-void
- This is a generic pass that converts a function with results into a void function.
- Since a kernel's compute function returns void, this pass unifies the abstraction.
module {
func.func @wrap0(%arg0: !okl.launcher_ctx) {
%0 = "okl.get_tensor_from_arg"(%arg0) {tensor_type = 0 : i32} : (!okl.launcher_ctx) -> tensor<1xf32>
%1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%2 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%3 = "okl.get_tensor_as_ret"(%arg0, %1) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
%4 = "okl.get_tensor_as_ret"(%arg0, %2) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
return
}
}
-lower-oneflow-to-okl
- Lowers the oneflow dialect to okl.
- TODO: this part may need to be split further into separate passes.
module {
func.func @_mlir__mlir_ciface_okl_func(%arg0: !okl.launcher_ctx) attributes {compiled = "true"} {
%0 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.source_reg_ctx %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>, sym_name = "relu-0"} : () -> !okl.reg_ctx
%1 = "okl.build_run_ctx"(%0, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.build_op_kernel"(%0) {op_type_name = "relu"} : (!okl.reg_ctx) -> !okl.kernel
"okl.launch"(%0, %1, %2) : (!okl.reg_ctx, !okl.run_ctx, !okl.kernel) -> ()
"okl.destroy_reg_ctx"(%0) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%1) : (!okl.run_ctx) -> ()
%3 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.source_reg_ctx %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>, sym_name = "relu-0"} : () -> !okl.reg_ctx
%4 = "okl.build_run_ctx"(%3, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%5 = "okl.build_op_kernel"(%3) {op_type_name = "relu"} : (!okl.reg_ctx) -> !okl.kernel
"okl.launch"(%3, %4, %5) : (!okl.reg_ctx, !okl.run_ctx, !okl.kernel) -> ()
"okl.destroy_reg_ctx"(%3) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%4) : (!okl.run_ctx) -> ()
return
}
}
module {
func.func @_mlir__mlir_ciface_okl_func(%arg0: !okl.launcher_ctx) attributes {compiled = "true"} {
%0 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%1 = "okl.build_run_ctx"(%0, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.build_op_kernel"(%0) {op_type_name = "relu"} : (!okl.reg_ctx) -> !okl.kernel
"okl.launch"(%0, %1, %2) : (!okl.reg_ctx, !okl.run_ctx, !okl.kernel) -> ()
"okl.destroy_reg_ctx"(%0) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%1) : (!okl.run_ctx) -> ()
%3 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%4 = "okl.build_run_ctx"(%3, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%5 = "okl.build_op_kernel"(%3) {op_type_name = "relu"} : (!okl.reg_ctx) -> !okl.kernel
"okl.launch"(%3, %4, %5) : (!okl.reg_ctx, !okl.run_ctx, !okl.kernel) -> ()
"okl.destroy_reg_ctx"(%3) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%4) : (!okl.run_ctx) -> ()
return
}
}
-split-into-funcs
module {
func.func @okl_recycle(%arg0: !okl.launcher_ctx) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0:3 = call @okl_get_resources(%arg0) : (!okl.launcher_ctx) -> (tensor<2x!okl.reg_ctx>, tensor<2x!okl.run_ctx>, tensor<2x!okl.kernel>)
%1 = tensor.extract %0#0[%c0] : tensor<2x!okl.reg_ctx>
%2 = tensor.extract %0#0[%c1] : tensor<2x!okl.reg_ctx>
%3 = tensor.extract %0#1[%c0] : tensor<2x!okl.run_ctx>
%4 = tensor.extract %0#1[%c1] : tensor<2x!okl.run_ctx>
"okl.destroy_reg_ctx"(%1) : (!okl.reg_ctx) -> ()
"okl.destroy_reg_ctx"(%2) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%3) : (!okl.run_ctx) -> ()
"okl.destroy_run_ctx"(%4) : (!okl.run_ctx) -> ()
return
}
func.func @okl_compute(%arg0: !okl.launcher_ctx) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0:3 = call @okl_get_resources(%arg0) : (!okl.launcher_ctx) -> (tensor<2x!okl.reg_ctx>, tensor<2x!okl.run_ctx>, tensor<2x!okl.kernel>)
%1 = tensor.extract %0#0[%c0] : tensor<2x!okl.reg_ctx>
%2 = tensor.extract %0#0[%c1] : tensor<2x!okl.reg_ctx>
%3 = tensor.extract %0#1[%c0] : tensor<2x!okl.run_ctx>
%4 = tensor.extract %0#1[%c1] : tensor<2x!okl.run_ctx>
%5 = tensor.extract %0#2[%c0] : tensor<2x!okl.kernel>
%6 = tensor.extract %0#2[%c1] : tensor<2x!okl.kernel>
"okl.launch"(%1, %3, %5) : (!okl.reg_ctx, !okl.run_ctx, !okl.kernel) -> ()
"okl.launch"(%2, %4, %6) : (!okl.reg_ctx, !okl.run_ctx, !okl.kernel) -> ()
return
}
func.func @okl_get_resources(%arg0: !okl.launcher_ctx) -> (tensor<2x!okl.reg_ctx>, tensor<2x!okl.run_ctx>, tensor<2x!okl.kernel>) {
%0 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%9 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %9 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%1 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%9 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %9 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%2 = tensor.from_elements %0, %1 : tensor<2x!okl.reg_ctx>
%3 = "okl.build_run_ctx"(%0, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%4 = "okl.build_run_ctx"(%1, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%5 = tensor.from_elements %3, %4 : tensor<2x!okl.run_ctx>
%6 = "okl.build_op_kernel"(%0) {op_type_name = "relu"} : (!okl.reg_ctx) -> !okl.kernel
%7 = "okl.build_op_kernel"(%1) {op_type_name = "relu"} : (!okl.reg_ctx) -> !okl.kernel
%8 = tensor.from_elements %6, %7 : tensor<2x!okl.kernel>
return %2, %5, %8 : tensor<2x!okl.reg_ctx>, tensor<2x!okl.run_ctx>, tensor<2x!okl.kernel>
}
}
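The three functions produced by -split-into-funcs form a build/use/destroy lifecycle over the resources of the launcher context: okl_get_resources builds the reg_ctx / run_ctx / kernel triples, okl_compute launches each kernel with its run_ctx, and okl_recycle destroys the contexts. A toy sketch of that split (all types and fields are stand-ins; note that, as in the IR above, both okl_compute and okl_recycle call okl_get_resources, and the real system caches the results in the launcher context):

```python
# Sketch of the split-into-funcs lifecycle.
def okl_get_resources(launcher_ctx):
    reg = [{"op": name} for name in launcher_ctx["ops"]]
    run = [{"reg": r, "alive": True} for r in reg]
    ker = [{"reg": r} for r in reg]
    return reg, run, ker

def okl_compute(launcher_ctx, launched):
    _, run, ker = okl_get_resources(launcher_ctx)
    for rc, k in zip(run, ker):
        launched.append(k["reg"]["op"])   # stands in for okl.launch

def okl_recycle(launcher_ctx):
    _, run, _ = okl_get_resources(launcher_ctx)
    for rc in run:
        rc["alive"] = False               # stands in for destroy_run_ctx

launched = []
okl_compute({"ops": ["relu-0", "relu-0"]}, launched)
print(launched)
```

Separating compute from construction and destruction is what later lets only okl_compute be lowered to LLVM and JIT-executed.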
The generated okl_get_resources function initializes the kernel-launch state, and the generated okl_compute function is lowered to LLVM and executed by the JIT.
-lower-to-llvm-func
module {
llvm.func @fetch_kernel(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @fetch_run_ctx(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @okl_compute(%arg0: !llvm.ptr<i8>) {
%0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.call @fetch_run_ctx(%arg0, %1) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%3 = llvm.mlir.constant(1 : index) : i64
%4 = llvm.call @fetch_run_ctx(%arg0, %3) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%5 = llvm.mlir.constant(0 : index) : i64
%6 = llvm.call @fetch_kernel(%arg0, %5) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%7 = llvm.mlir.constant(1 : index) : i64
%8 = llvm.call @fetch_kernel(%arg0, %7) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
llvm.return
}
}
-reconcile-unrealized-casts
module {
llvm.func @fetch_kernel(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @fetch_run_ctx(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @okl_compute(%arg0: !llvm.ptr<i8>) {
%0 = llvm.mlir.constant(0 : index) : i64
%1 = llvm.call @fetch_run_ctx(%arg0, %0) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%2 = llvm.mlir.constant(1 : index) : i64
%3 = llvm.call @fetch_run_ctx(%arg0, %2) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%4 = llvm.mlir.constant(0 : index) : i64
%5 = llvm.call @fetch_kernel(%arg0, %4) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%6 = llvm.mlir.constant(1 : index) : i64
%7 = llvm.call @fetch_kernel(%arg0, %6) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
llvm.return
}
}
// -----// IR Dump After ExtractKernelLaunchTensorPass //----- //
module {
func.func @wrap0(%arg0: !okl.launcher_ctx) -> (tensor<1xf32>, tensor<1xf32>) {
%0 = "okl.get_tensor_from_arg"(%arg0) {tensor_type = 0 : i32} : (!okl.launcher_ctx) -> tensor<1xf32>
%1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%2 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%3 = "okl.get_tensor_as_ret"(%arg0, %1) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
%4 = "okl.get_tensor_as_ret"(%arg0, %2) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
return %3, %4 : tensor<1xf32>, tensor<1xf32>
}
}
// -----// IR Dump After TrimReturnAsVoidPass //----- //
module {
func.func @wrap0(%arg0: !okl.launcher_ctx) {
%0 = "okl.get_tensor_from_arg"(%arg0) {tensor_type = 0 : i32} : (!okl.launcher_ctx) -> tensor<1xf32>
%1 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%2 = "oneflow.relu"(%0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
%3 = "okl.get_tensor_as_ret"(%arg0, %1) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
%4 = "okl.get_tensor_as_ret"(%arg0, %2) {tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<1xf32>) -> tensor<1xf32>
return
}
}
// -----// IR Dump After LowerToOKLPass //----- //
module {
func.func @_mlir__mlir_ciface_okl_func(%arg0: !okl.launcher_ctx) attributes {compiled = "true"} {
%0 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%1 = "okl.build_run_ctx"(%0, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.build_op_kernel"(%0) : (!okl.reg_ctx) -> !okl.kernel
"okl.launch"(%1, %2) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.destroy_reg_ctx"(%0) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%1) : (!okl.run_ctx) -> ()
%3 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%4 = "okl.build_run_ctx"(%3, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%5 = "okl.build_op_kernel"(%3) : (!okl.reg_ctx) -> !okl.kernel
"okl.launch"(%4, %5) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.destroy_reg_ctx"(%3) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%4) : (!okl.run_ctx) -> ()
return
}
}
// -----// IR Dump After SplitIntoFuncsPass //----- //
module {
func.func @okl_recycle(%arg0: !okl.launcher_ctx) {
%0:2 = call @get_resources_type_0(%arg0) : (!okl.launcher_ctx) -> (!okl.reg_ctx, !okl.reg_ctx)
%1:2 = call @get_resources_type_1(%arg0) : (!okl.launcher_ctx) -> (!okl.run_ctx, !okl.run_ctx)
%2:2 = call @get_resources_type_2(%arg0) : (!okl.launcher_ctx) -> (!okl.kernel, !okl.kernel)
"okl.destroy_reg_ctx"(%0#0) : (!okl.reg_ctx) -> ()
"okl.destroy_reg_ctx"(%0#1) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%1#0) : (!okl.run_ctx) -> ()
"okl.destroy_run_ctx"(%1#1) : (!okl.run_ctx) -> ()
return
}
func.func @okl_compute(%arg0: !okl.launcher_ctx) {
%0:2 = call @get_resources_type_0(%arg0) : (!okl.launcher_ctx) -> (!okl.reg_ctx, !okl.reg_ctx)
%1:2 = call @get_resources_type_1(%arg0) : (!okl.launcher_ctx) -> (!okl.run_ctx, !okl.run_ctx)
%2:2 = call @get_resources_type_2(%arg0) : (!okl.launcher_ctx) -> (!okl.kernel, !okl.kernel)
"okl.launch"(%1#0, %2#0) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.launch"(%1#1, %2#1) : (!okl.run_ctx, !okl.kernel) -> ()
return
}
func.func @okl_init_context(%arg0: !okl.launcher_ctx) {
%0 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%1 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%2 = "okl.build_run_ctx"(%0, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%3 = "okl.build_run_ctx"(%1, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%4 = "okl.build_op_kernel"(%0) : (!okl.reg_ctx) -> !okl.kernel
%5 = "okl.build_op_kernel"(%1) : (!okl.reg_ctx) -> !okl.kernel
return
}
func.func private @get_resources_type_0(!okl.launcher_ctx) -> (!okl.reg_ctx, !okl.reg_ctx)
func.func private @get_resources_type_1(!okl.launcher_ctx) -> (!okl.run_ctx, !okl.run_ctx)
func.func private @get_resources_type_2(!okl.launcher_ctx) -> (!okl.kernel, !okl.kernel)
}
// -----// IR Dump After FetchFromLauncherPass //----- //
module {
func.func @okl_recycle(%arg0: !okl.launcher_ctx) {
%0 = "okl.fetch_reg_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.reg_ctx
%1 = "okl.fetch_reg_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.reg_ctx
%2 = "okl.fetch_run_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%3 = "okl.fetch_run_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
"okl.destroy_reg_ctx"(%0) : (!okl.reg_ctx) -> ()
"okl.destroy_reg_ctx"(%1) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%2) : (!okl.run_ctx) -> ()
"okl.destroy_run_ctx"(%3) : (!okl.run_ctx) -> ()
return
}
func.func @okl_compute(%arg0: !okl.launcher_ctx) {
%0 = "okl.fetch_run_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%1 = "okl.fetch_run_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.fetch_kernel"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.kernel
%3 = "okl.fetch_kernel"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.kernel
"okl.launch"(%0, %2) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.launch"(%1, %3) : (!okl.run_ctx, !okl.kernel) -> ()
return
}
func.func @okl_init_context(%arg0: !okl.launcher_ctx) {
%0 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%1 = "okl.build_reg_ctx"() ({
^bb0(%arg1: tensor<1xf32>):
%6 = "oneflow.relu"(%arg1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 12 : i64} : (tensor<1xf32>) -> tensor<1xf32>
okl.return %6 : tensor<1xf32>
}) {function_type = (tensor<1xf32>) -> tensor<1xf32>} : () -> !okl.reg_ctx
%2 = "okl.build_run_ctx"(%0, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%3 = "okl.build_run_ctx"(%1, %arg0) : (!okl.reg_ctx, !okl.launcher_ctx) -> !okl.run_ctx
%4 = "okl.build_op_kernel"(%0) : (!okl.reg_ctx) -> !okl.kernel
%5 = "okl.build_op_kernel"(%1) : (!okl.reg_ctx) -> !okl.kernel
return
}
func.func private @get_resources_type_0(!okl.launcher_ctx) -> (!okl.reg_ctx, !okl.reg_ctx)
func.func private @get_resources_type_1(!okl.launcher_ctx) -> (!okl.run_ctx, !okl.run_ctx)
func.func private @get_resources_type_2(!okl.launcher_ctx) -> (!okl.kernel, !okl.kernel)
}
// -----// IR Dump After OnlyKeepComputeOpsPass //----- //
module {
func.func @okl_compute(%arg0: !okl.launcher_ctx) {
%0 = "okl.fetch_run_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%1 = "okl.fetch_run_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.fetch_kernel"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.kernel
%3 = "okl.fetch_kernel"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.kernel
"okl.launch"(%0, %2) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.launch"(%1, %3) : (!okl.run_ctx, !okl.kernel) -> ()
return
}
}
// -----// IR Dump After LowerOKLToLLVMFuncPass //----- //
module {
llvm.func @okl_compute(%arg0: !llvm.ptr<i8>) {
%0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx
%1 = "okl.fetch_run_ctx"(%0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.fetch_run_ctx"(%0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%3 = "okl.fetch_kernel"(%0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.kernel
%4 = "okl.fetch_kernel"(%0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.kernel
"okl.launch"(%1, %3) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.launch"(%2, %4) : (!okl.run_ctx, !okl.kernel) -> ()
llvm.return
}
}
// -----// IR Dump After LowerOKLToLLVMCallPass //----- //
module {
llvm.func @launch(!llvm.ptr<i8>, !llvm.ptr<i8>) attributes {llvm.emit_c_interface}
llvm.func @fetch_kernel(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @fetch_run_ctx(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @okl_compute(%arg0: !llvm.ptr<i8>) {
%0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.call @fetch_run_ctx(%arg0, %1) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%3 = llvm.mlir.constant(1 : index) : i64
%4 = llvm.call @fetch_run_ctx(%arg0, %3) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%5 = llvm.mlir.constant(0 : index) : i64
%6 = llvm.call @fetch_kernel(%arg0, %5) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%7 = llvm.mlir.constant(1 : index) : i64
%8 = llvm.call @fetch_kernel(%arg0, %7) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
llvm.call @launch(%2, %6) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
llvm.call @launch(%4, %8) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
llvm.return
}
}
// -----// IR Dump After ReconcileUnrealizedCasts //----- //
module {
llvm.func @launch(!llvm.ptr<i8>, !llvm.ptr<i8>) attributes {llvm.emit_c_interface}
llvm.func @fetch_kernel(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @fetch_run_ctx(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @okl_compute(%arg0: !llvm.ptr<i8>) {
%0 = llvm.mlir.constant(0 : index) : i64
%1 = llvm.call @fetch_run_ctx(%arg0, %0) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%2 = llvm.mlir.constant(1 : index) : i64
%3 = llvm.call @fetch_run_ctx(%arg0, %2) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%4 = llvm.mlir.constant(0 : index) : i64
%5 = llvm.call @fetch_kernel(%arg0, %4) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%6 = llvm.mlir.constant(1 : index) : i64
%7 = llvm.call @fetch_kernel(%arg0, %6) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
llvm.call @launch(%1, %5) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
llvm.call @launch(%3, %7) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
llvm.return
}
}
Setting ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH = 1 enables the oneflow kernel-launch feature, which wraps computational ops.
At the end of the roundtrip, consecutive computational ops are merged into a single kernel_launch op (oneflow ops -> oneflow.kernel_launch{mlir_assembly="wrap"}). For example, in
module {
oneflow.job @GraphToRun_1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_input.0.0_2", output_lbns = ["_GraphToRun_1_input.0.0_2/out"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>
%0 = "oneflow.relu"(%output) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
%1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
%output_0 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_output.0.0_2", output_lbns = ["_GraphToRun_1_output.0.0_2/out"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>
oneflow.return %output_0 : tensor<2xf32>
}
}
the consecutive relu and tanh ops are merged into the single kernel_launch op below, and their resources are mapped according to fixed rules inside the kernel_launch op's wrap function:
module {
func.func @wrap0(%arg0: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) attributes {llvm.emit_c_interface} {
%0 = "oneflow.relu"(%arg0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
%1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
return %0, %1 : tensor<2xf32>, tensor<2xf32>
}
oneflow.job @GraphToRun_1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_input.0.0_2", output_lbns = ["_GraphToRun_1_input.0.0_2/out"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>
%0:2 = oneflow.kernel_launch @wrap0(%output) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], mlir_assembly = "\22func.func\22() ({\0A^bb0(%arg0: tensor<2xf32>):\0A %0 = \22oneflow.relu\22(%arg0) {device_name = [\22@0:0\22], device_tag = \22cuda\22, hierarchy = [1], op_name = \22relu-0\22, scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>\0A %1 = \22oneflow.tanh\22(%0) {device_name = [\22@0:0\22], device_tag = \22cuda\22, hierarchy = [1], op_name = \22tanh-1\22, scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>\0A \22func.return\22(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> ()\0A}) {function_type = (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>), llvm.emit_c_interface, sym_name = \22wrap0\22} : () -> ()", op_name = "wrap0", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%output_0 = "oneflow.output"(%0#1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_GraphToRun_1_output.0.0_2", output_lbns = ["_GraphToRun_1_output.0.0_2/out"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>
oneflow.return %output_0 : tensor<2xf32>
}
}
In the kernel_launch class, input and output shapes are inferred from the wrap function's function type:
- set input shape (inferred from function_type operands)
- set output shape (inferred from function_type results)
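The inference above can be illustrated with a toy parser over the textual function type, e.g. "(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)". A real implementation would query the MLIR FunctionType directly; the regex here is a simplification that only handles 1-D static f32 shapes:

```python
import re

# Derive input/output shapes from a wrap function's type string.
def shapes_from_function_type(ty):
    operands, results = ty.split("->")

    def shapes(part):
        # Match 1-D static shapes like tensor<2xf32>.
        return [int(m) for m in re.findall(r"tensor<(\d+)xf32>", part)]

    return shapes(operands), shapes(results)

print(shapes_from_function_type("(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)"))
```

For the wrap0 function above this yields one input of shape [2] and two outputs of shape [2].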
During kernel_launch initialization:
- lower mlir_assembly="wrap" to an okl module
- globally create the launcher context once (init_context, okl_module)
  - vec<reg_ctx> : datatype/device/input/output (information from the wrap op)
  - vec<run_ctx> : input/output for the kernel compute function (information from reg_ctx, tensors from init_context)
  - vec<kernel> : datatype/device (information from reg_ctx)
- lower okl compute to an llvm module
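The initialization steps above can be sketched as building three parallel vectors inside the launcher context, one entry per wrapped op (the dict layout and field names are invented for illustration; in the real system these are C++ objects holding oneflow kernel state):

```python
# Toy model of kernel_launch initialization: one reg_ctx per wrapped op,
# then a run_ctx and a kernel derived from each reg_ctx, all kept in
# parallel vectors inside the launcher context.
def init_context(wrap_ops):
    reg_ctxs = [{"op": op["op"], "device": op["device"]} for op in wrap_ops]
    run_ctxs = [{"io_from": rc["op"]} for rc in reg_ctxs]   # tensors come from init_context
    kernels  = [{"device": rc["device"]} for rc in reg_ctxs]
    return {"reg": reg_ctxs, "run": run_ctxs, "kernel": kernels}

ctx = init_context([{"op": "relu-0", "device": "cuda"},
                    {"op": "tanh-1", "device": "cuda"}])
print(len(ctx["reg"]), len(ctx["run"]), len(ctx["kernel"]))
```

Because the vectors are index-aligned, the later fetch_* ops can refer to every resource by a plain integer index.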
Here, lower oneflow to okl converts:
module {
func.func @wrap0(%arg0: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) attributes {llvm.emit_c_interface} {
%0 = "oneflow.relu"(%arg0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
%1 = "oneflow.tanh"(%0) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
return %0, %1 : tensor<2xf32>, tensor<2xf32>
}
into okl form via the following passes:
- -extract-kernel-launch-tensor
- -trim-return-as-void
- -lower-to-okl
- -split-into-funcs
- -fetch-from-launcher
The result is:
module {
func.func @okl_recycle(%arg0: !okl.launcher_ctx) {
%0 = "okl.fetch_reg_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.reg_ctx
%1 = "okl.fetch_reg_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.reg_ctx
%2 = "okl.fetch_run_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%3 = "okl.fetch_run_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
"okl.destroy_reg_ctx"(%0) : (!okl.reg_ctx) -> ()
"okl.destroy_reg_ctx"(%1) : (!okl.reg_ctx) -> ()
"okl.destroy_run_ctx"(%2) : (!okl.run_ctx) -> ()
"okl.destroy_run_ctx"(%3) : (!okl.run_ctx) -> ()
return
}
func.func @okl_compute(%arg0: !okl.launcher_ctx) {
%0 = "okl.fetch_run_ctx"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%1 = "okl.fetch_run_ctx"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.run_ctx
%2 = "okl.fetch_kernel"(%arg0) {index = 0 : si64} : (!okl.launcher_ctx) -> !okl.kernel
%3 = "okl.fetch_kernel"(%arg0) {index = 1 : si64} : (!okl.launcher_ctx) -> !okl.kernel
"okl.launch"(%0, %2) : (!okl.run_ctx, !okl.kernel) -> ()
"okl.launch"(%1, %3) : (!okl.run_ctx, !okl.kernel) -> ()
return
}
func.func @okl_init_context(%arg0: !okl.launcher_ctx) {
%0 = "okl.build_reg_ctx"() ({
%6 = "okl.get_tensor_from_arg"(%arg0) {index = 0 : i32, tensor_type = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>
%7 = "oneflow.relu"(%6) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "relu-0", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
%8 = "okl.get_tensor_as_ret"(%arg0, %7) {index = 0 : i32, tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>
okl.return
}) {function_type = () -> ()} : () -> !okl.reg_ctx
%1 = "okl.build_reg_ctx"() ({
%6 = "okl.get_tensor_from_ret"(%arg0) {index = 0 : i32, tensor_type = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>
%7 = "oneflow.tanh"(%6) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "tanh-1", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>
%8 = "okl.get_tensor_as_ret"(%arg0, %7) {index = 1 : i32, tensor_type = 2 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>
okl.return
}) {function_type = () -> ()} : () -> !okl.reg_ctx
%2 = "okl.build_run_ctx"(%0) : (!okl.reg_ctx) -> !okl.run_ctx
%3 = "okl.build_run_ctx"(%1) : (!okl.reg_ctx) -> !okl.run_ctx
%4 = "okl.build_op_kernel"(%0) : (!okl.reg_ctx) -> !okl.kernel
%5 = "okl.build_op_kernel"(%1) : (!okl.reg_ctx) -> !okl.kernel
return
}
func.func private @get_resources_type_0(!okl.launcher_ctx) -> (!okl.reg_ctx, !okl.reg_ctx)
func.func private @get_resources_type_1(!okl.launcher_ctx) -> (!okl.run_ctx, !okl.run_ctx)
func.func private @get_resources_type_2(!okl.launcher_ctx) -> (!okl.kernel, !okl.kernel)
}
Here okl_init_context implements the resource initialization of the second step above, while okl_compute is later used to invoke each kernel and run its compute.
The following passes then lower OKL to LLVM form:
- -only-keep-compute-ops
- -lower-launcher-to-llvm-ptr
- -lower-okl-to-llvm-call
- -reconcile-unrealized-casts
- -convert-func-to-llvm
module attributes {llvm.data_layout = ""} {
llvm.func @launch(!llvm.ptr<i8>, !llvm.ptr<i8>) attributes {llvm.emit_c_interface}
llvm.func @fetch_kernel(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @fetch_run_ctx(!llvm.ptr<i8>, i64) -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}
llvm.func @okl_compute(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {
%0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.call @fetch_run_ctx(%arg0, %1) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%3 = llvm.mlir.constant(1 : index) : i64
%4 = llvm.call @fetch_run_ctx(%arg0, %3) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%5 = llvm.mlir.constant(0 : index) : i64
%6 = llvm.call @fetch_kernel(%arg0, %5) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
%7 = llvm.mlir.constant(1 : index) : i64
%8 = llvm.call @fetch_kernel(%arg0, %7) : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
llvm.call @launch(%2, %6) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
llvm.call @launch(%4, %8) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
llvm.return
}
llvm.func @_mlir_ciface_okl_compute(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {
llvm.call @okl_compute(%arg0) : (!llvm.ptr<i8>) -> ()
llvm.return
}
}
Finally, in the compute function, the LLVM engine runs this LLVM-level compute (which may execute many times):
- llvm.engine(llvm module, launcher context)
- fetch run_ctx (from launcher context, by index)
- fetch kernel (from launcher context, by index)
- launch kernel(run_ctx)
The callees of the llvm.call ops are implemented in liboneflow.so via extern "C".
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
CI failed when running job: Build cpu. PR label automerge has been removed
Speed stats:
GPU Name: GeForce GTX 1080
❌ OneFlow resnet50 time: 139.6ms (= 13961.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 160.5ms (= 16048.9ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 160.5ms / 139.6ms)
OneFlow resnet50 time: 84.7ms (= 8471.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 101.7ms (= 10170.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.20 (= 101.7ms / 84.7ms)
OneFlow resnet50 time: 57.2ms (= 11431.7ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 77.7ms (= 15538.0ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.36 (= 77.7ms / 57.2ms)
OneFlow resnet50 time: 44.5ms (= 8891.7ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.3ms (= 13664.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.54 (= 68.3ms / 44.5ms)
OneFlow resnet50 time: 39.9ms (= 7976.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 72.5ms (= 14503.3ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.82 (= 72.5ms / 39.9ms)
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9144/