CCNet
Hope for mixed-precision training support
The function rcca.ca_forward_cuda does not seem to support mixed-precision training. When the opt_level of NVIDIA Apex is set to "O1", it raises the following error:

RuntimeError: expected scalar type Float but found Half (data

Training runs fine when I switch opt_level to "O0", so the cause should be the data type.
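A possible workaround while the kernel is float-only, as a minimal sketch: cast to float32 right before the float-only code and cast back afterwards. The wrapper below is illustrative (FP32Wrapper and the module it wraps are assumptions, not part of this repo); note that Apex O1 patches ops at the functional level, so the cast may need to sit immediately before the rcca call inside the attention module rather than at the module boundary.

```python
import torch.nn as nn


class FP32Wrapper(nn.Module):
    """Run a float-only module in float32 inside a mixed-precision model.

    `inner` stands for the existing CUDA-backed attention module; inputs are
    cast to float32 before it and the output is cast back to the caller's
    dtype, so the rest of the network can stay in mixed precision.
    """
    def __init__(self, inner):
        super().__init__()
        self.inner = inner.float()  # keep the wrapped module's weights in fp32

    def forward(self, x):
        orig_dtype = x.dtype
        out = self.inner(x.float())
        return out.to(orig_dtype)
```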
@DotWang We do not have a plan to make it support mixed-precision training. It would be great if you could achieve it. Alternatively, you can replace the CUDA-implemented RCCA module with a pure PyTorch one.
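For reference, here is a minimal pure-PyTorch sketch of criss-cross attention along the lines of the paper's formulation. The layer names, the C//8 channel reduction, and the -inf masking of the centre pixel follow common re-implementations rather than this repo's CUDA kernel, so it should be checked against rcca for numerical agreement before swapping it in. Because it only uses standard PyTorch ops, it works under Apex mixed precision.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def _neg_inf_diag(B, H, W, device, dtype):
    # -inf on the diagonal of the HxH energy so the centre pixel is not
    # counted twice (it also appears in the row-wise branch).
    diag = torch.diag(torch.full((H,), float('-inf'), device=device, dtype=dtype))
    return diag.unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttentionPT(nn.Module):
    """Pure-PyTorch criss-cross attention (a sketch, not this repo's kernel)."""
    def __init__(self, in_channels):
        super().__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, _, H, W = x.size()
        q = self.query_conv(x)
        k = self.key_conv(x)
        v = self.value_conv(x)

        # Column-wise (vertical) energies: (B*W, H, H)
        q_h = q.permute(0, 3, 1, 2).contiguous().view(B * W, -1, H).permute(0, 2, 1)
        k_h = k.permute(0, 3, 1, 2).contiguous().view(B * W, -1, H)
        energy_h = (torch.bmm(q_h, k_h) + _neg_inf_diag(B, H, W, x.device, x.dtype)) \
            .view(B, W, H, H).permute(0, 2, 1, 3)

        # Row-wise (horizontal) energies: (B*H, W, W)
        q_w = q.permute(0, 2, 1, 3).contiguous().view(B * H, -1, W).permute(0, 2, 1)
        k_w = k.permute(0, 2, 1, 3).contiguous().view(B * H, -1, W)
        energy_w = torch.bmm(q_w, k_w).view(B, H, W, W)

        # Softmax over the union of each position's column and row.
        attn = F.softmax(torch.cat([energy_h, energy_w], dim=3), dim=3)
        attn_h = attn[:, :, :, :H].permute(0, 2, 1, 3).contiguous().view(B * W, H, H)
        attn_w = attn[:, :, :, H:].contiguous().view(B * H, W, W)

        # Aggregate values along the column and the row, then add the residual.
        v_h = v.permute(0, 3, 1, 2).contiguous().view(B * W, -1, H)
        v_w = v.permute(0, 2, 1, 3).contiguous().view(B * H, -1, W)
        out_h = torch.bmm(v_h, attn_h.permute(0, 2, 1)) \
            .view(B, W, -1, H).permute(0, 2, 3, 1)
        out_w = torch.bmm(v_w, attn_w.permute(0, 2, 1)) \
            .view(B, H, -1, W).permute(0, 2, 1, 3)

        return self.gamma * (out_h + out_w) + x
```

Dropping a module like this in place of the CUDA one trades some speed and memory for compatibility with O1.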
OK, thank you~