Clarifying global floating point policy
The issue of float precision affects many computations in tensorflow_ranking, such as:

https://github.com/tensorflow/ranking/blob/a928e2b1930a1ebcae2c509e3f6ca95941fd1e49/tensorflow_ranking/python/metrics_impl.py#L603-L628
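For reference, the pattern in question has roughly this shape (a generic illustration of the hardcoded-dtype style, not the exact source; the function name is hypothetical):

```python
import tensorflow as tf

def example_gain(labels):
    # The dtype is pinned to tf.float32 here, regardless of
    # tf.keras.backend.floatx() or the global mixed-precision policy.
    return tf.pow(2.0, tf.cast(labels, dtype=tf.float32)) - 1.0
```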
This has been mentioned before in #254, but I want to elaborate on our difficulties.
This kind of hardcoded dtype makes it extremely hard to move our programs to float64.
For example, if we call tf.keras.backend.set_floatx('float64') anywhere, we get errors within tensorflow_ranking due to conflicting dtypes.
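Here is a minimal sketch of the kind of failure we see (assuming tfr.keras.metrics.NDCGMetric as the entry point; the exact error depends on which op first mixes the dtypes):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

tf.keras.backend.set_floatx('float64')

# Inputs follow the global policy and are float64.
labels = tf.constant([[1.0, 0.0, 2.0]], dtype=tf.float64)
scores = tf.constant([[0.3, 0.1, 0.6]], dtype=tf.float64)

metric = tfr.keras.metrics.NDCGMetric()
# In our setup this raises a dtype-mismatch error, because internal
# tensors are created with a hardcoded tf.float32.
metric.update_state(labels, scores)
```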
Will the global floating point policy (tf.keras.mixed_precision.set_global_policy and tf.keras.backend.floatx) be supported?
If the official stance is that the global policy is ignored, can that be documented?
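For comparison, the pattern we are hoping for would derive the dtype from the global setting instead of hardcoding it (a sketch only; the helper name is hypothetical):

```python
import tensorflow as tf

def policy_aware_gain(labels):
    # Resolve the dtype at call time so that set_floatx('float64')
    # (or a mixed-precision policy) is respected.
    dtype = tf.keras.backend.floatx()
    return tf.pow(tf.constant(2.0, dtype=dtype), tf.cast(labels, dtype)) - 1.0
```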