Using FP32 for DRQ network.
Casts the np.log of the temperature to FP32.
I am wondering what the motivation for this change is.
If we are changing this, I hope we can update the upstream code at the same time: https://github.com/denisyarats/drq/blob/master/drq.py#L188
Hi @xuzhao9, the motivation is that Metal doesn't support FP64, so this benchmark fails on the MPS device. As the change shows, np.log returns FP64 by default. If switching to FP32 won't have any impact on the network's performance, it would be nice to just use FP32.
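For reference, here is a minimal NumPy-only sketch of the dtype issue, assuming the temperature is initialized from a plain Python float (the value `0.1` below is hypothetical). `np.log` on a Python float yields a `float64` scalar, and a tensor built from it would presumably inherit that dtype. Casting to `np.float32`, or passing `dtype=torch.float32` to `torch.tensor` on the PyTorch side, keeps the value in single precision.

```python
import numpy as np

# Hypothetical initial temperature, for illustration only.
init_temperature = 0.1

# np.log on a Python float returns a NumPy float64 scalar.
log_alpha_fp64 = np.log(init_temperature)

# Explicitly casting keeps the value in single precision (supported on Metal/MPS).
log_alpha_fp32 = np.float32(np.log(init_temperature))

print(log_alpha_fp64.dtype)  # float64
print(log_alpha_fp32.dtype)  # float32
```

The two values differ only by float32 rounding (well under 1e-6 here). Whether that difference matters for training is the accuracy question raised below.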
@kulinseth Since we don't do any end-to-end accuracy checks in the core models set, we need to be very careful when modifying the model code.
Therefore, I believe we can only accept this PR if the upstream model repo also uses FP32; otherwise, we will keep the code as-is.
cc @adnanaziz for ideas and suggestions.
Hi @xuzhao9, that makes sense. I will create an upstream PR to propose a similar fix.