Retrieval Task Mixed Precision Dtype Issue
With a mixed precision policy, I observed that the query_embeddings for the retrieval task get cast to the global precision policy's dtype while the candidate_embeddings preserve their own dtype. This causes the linalg computation in the Retrieval task's call method to fail due to mismatched dtypes.
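For context, here is a minimal sketch of the failure mode (the embeddings and shapes are made up; this just mimics the dtype mismatch, not the actual tfrs internals):

```python
import tensorflow as tf

# Made-up embeddings reproducing the mismatch: under a mixed precision
# policy the query tower output comes back float16 while the candidate
# embeddings stay float32.
query_embeddings = tf.zeros((4, 8), dtype=tf.float16)
candidate_embeddings = tf.zeros((6, 8), dtype=tf.float32)

try:
    # The Retrieval task's call method performs a similar linalg op to
    # score queries against candidates.
    scores = tf.linalg.matmul(
        query_embeddings, candidate_embeddings, transpose_b=True
    )
    mismatch_error = None
except (tf.errors.InvalidArgumentError, TypeError) as e:
    mismatch_error = e

print(type(mismatch_error).__name__)
```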

I have the same problem!
This is pretty annoying and I have been grappling with it for days. Does anyone know where in the code this is enforced (i.e. where query_embeddings gets cast to the global_policy dtype)?
One workaround is to keep the global_policy in float32 and construct the tf.keras.Model objects with dtype mixed_float16. Here's some example code.
Caveat: I don't really know whether such a model achieves the full performance gains of mixed precision training, nor have I done a proper accuracy comparison yet. Nonetheless, it does work, and the TensorBoard profiling I did shows it using float16 on the GPU.
Another caveat: some layers don't seem to derive their dtype from the model (e.g. Embeddings, Cross, ...), so for those you need to specify dtype explicitly when constructing the layer.
```python
from tensorflow.keras import mixed_precision
import tensorflow as tf
import tensorflow_recommenders as tfrs

# Keep the global policy at float32 so the Retrieval task computes in float32.
mixed_precision.set_global_policy("float32")

class TwoTowerModel(tfrs.models.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The towers themselves run in mixed precision.
        self.query_model = QueryModel(dtype="mixed_float16")
        self.candidate_model = CandidateModel(dtype="mixed_float16")
        self.task = tfrs.tasks.Retrieval(...)

    def compute_loss(self, features, training=False):
        # Cast the float16 tower outputs back to float32 before the task.
        query_embeddings = tf.cast(self.query_model(features["query"]), tf.float32)
        item_embeddings = tf.cast(self.candidate_model(features["item"]), tf.float32)
        return self.task(
            query_embeddings=query_embeddings,
            candidate_embeddings=item_embeddings,
            compute_metrics=not training,
        )
```
Keras automatically casts input to layers depending on their dtype.
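For example, with a standalone Dense layer (not the tfrs task), a mixed_float16 layer automatically casts a float32 input down to its float16 compute dtype:

```python
import tensorflow as tf

# A layer built with a mixed_float16 policy casts float32 inputs to
# float16 (its compute dtype) before running the computation.
layer = tf.keras.layers.Dense(4, dtype="mixed_float16")
out = layer(tf.zeros((2, 8), dtype=tf.float32))
print(out.dtype)  # float16
```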
The issue is that the Retrieval task is a Keras layer whose dtype is set from the global policy. The dtype parameter can't be set for this layer currently. This also applies to the loss correction layers SamplingBiasCorrection, RemoveAccidentalHits, and HardNegativeMining.
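You can see the mechanism with a plain Keras layer (not tfrs-specific): a layer constructed without a dtype argument picks up whatever the global policy is at construction time.

```python
import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")
# Constructed without dtype=..., as tfrs.tasks.Retrieval currently is:
layer = tf.keras.layers.Layer()
policy_name = layer.dtype_policy.name
print(policy_name)  # mixed_float16
mixed_precision.set_global_policy("float32")  # restore the default
```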
So essentially it's best to follow @thushv89's advice for now.
Would a PR exposing the layer dtype for the task and loss layers be welcome, @maciejkula? I'm happy to contribute one.
Same problem here!
+1