Paddle
Paddle copied to clipboard
[XPU] add fused_linear_param_grad_add kernel
PR Category
Custom Device
PR Types
New features
Description
新增fused_linear_param_grad_add算子。为了实现这个算子,做了以下事情:
- 参考GPU的实现,将现有的算子
paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_grad_kernel.cc代码进行了拆分,拆出来一个新的头文件paddle/phi/kernels/funcs/fused_gemm_epilogue_xpu.h。 - 参考GPU的实现,在这个文件中
paddle/phi/kernels/fusion/xpu/fused_linear_param_grad_add_kernel.cc新增了算子实现,引用了上面拆出来的那个头文件。
局限性:目前只支持float32类型。但是GPU中带来的multi_precision相关的代码仍然保留,等MatMulXPUFunction及其周边矩阵乘法基础组件支持多种数据类型以后再加其它数据类型的支持。