Retrieval Task Mixed Precision Dtype Issue

Open Pelps12 opened this issue 3 years ago • 7 comments

With a mixed precision policy, I observed that the query_embeddings passed to the Retrieval task get cast to the dtype of the global precision policy, while the candidate_embeddings keep their original dtype. The linalg computation in the task's call method then fails because the two dtypes differ. [Screenshot 2022-06-05 022000]

Pelps12 avatar Jun 05 '22 07:06 Pelps12

I have the same problem!

Hedluund avatar Dec 19 '22 16:12 Hedluund

This is pretty annoying, and I have been grappling with it for days. Does anyone know where in the code this is enforced (i.e. query_embeddings getting cast to the global_policy dtype)?

One workaround is to keep the global_policy at float32 and give the tf.keras.Model objects the mixed_float16 policy. Here's some example code.

Caveat: I don't know whether such a model achieves the performance gains usually obtained with mixed precision training, nor have I done a proper accuracy comparison yet. Nonetheless, it does work, and from the TensorBoard profiling I did, it does use float16 on the GPU.

Another caveat: it looks like some layers don't derive their dtype from the model (e.g. Embedding, Cross, ...), so for those you need to specify dtype explicitly when constructing the layer.

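import tensorflow as tf
import tensorflow_recommenders as tfrs
from tensorflow.keras import mixed_precision

# Keep the global policy at float32 so the Retrieval task itself runs in float32.
mixed_precision.set_global_policy('float32')

class TwoTowerModel(tfrs.models.Model):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # QueryModel and CandidateModel are your own tower models (not shown).
        # Only the towers run under the mixed_float16 policy.
        self.query_model = QueryModel(dtype="mixed_float16")
        self.candidate_model = CandidateModel(dtype="mixed_float16")
        self.task = tfrs.tasks.Retrieval(...)

    def compute_loss(self, features, training=False):
        # Cast both tower outputs to float32 so the task sees matching dtypes.
        query_embeddings = tf.cast(self.query_model(features["query"]), tf.float32)
        item_embeddings = tf.cast(self.candidate_model(features["item"]), tf.float32)

        return self.task(query_embeddings=query_embeddings, candidate_embeddings=item_embeddings, compute_metrics=not training)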
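For the second caveat above, here is a minimal sketch of passing the policy directly to a single layer while the global policy stays at float32. The layer sizes and the `ids` tensor are made up for illustration, and the exact behaviour may vary between Keras versions.

```python
import tensorflow as tf

# The global policy stays at float32, as in the workaround above.
tf.keras.mixed_precision.set_global_policy("float32")

# Pass the policy explicitly to the layer instead of relying on the parent model.
# input_dim and output_dim are arbitrary illustrative values.
embedding = tf.keras.layers.Embedding(
    input_dim=100, output_dim=8, dtype="mixed_float16"
)

ids = tf.constant([[1, 2, 3]])
vectors = embedding(ids)

# Under mixed_float16 the weights and the computation use different dtypes.
print("variable dtype:", embedding.variable_dtype)
print("compute dtype:", embedding.compute_dtype)
print("output dtype:", vectors.dtype)
```

Checking `compute_dtype` on each layer of a tower is a quick way to spot layers that did not pick up the intended policy.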

thushv89 avatar Dec 28 '22 02:12 thushv89

Keras automatically casts the inputs to a layer according to the layer's dtype. The issue is that the Retrieval task is a Keras layer whose dtype is taken from the global policy, and the dtype parameter currently can't be set for this layer. The same issue applies to the loss correction layers SamplingBiasCorrection, RemoveAccidentalHits, and HardNegativeMining.
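The automatic casting can be seen in isolation with a small sketch. `TwoInputLayer` below is a hypothetical stand-in for the task, not the actual Retrieval implementation. In TF 2.x Keras, the mixed precision documentation notes that only tensors in the first argument to `call` are cast, which would match the mismatch reported here; other Keras versions may cast further arguments as well, so the second printed dtype can vary.

```python
import tensorflow as tf


class TwoInputLayer(tf.keras.layers.Layer):
    """Hypothetical stand-in for a task layer that takes two tensors."""

    def call(self, query_embeddings, candidate_embeddings):
        # Return the inputs as the layer sees them, after any automatic casting.
        return query_embeddings, candidate_embeddings


# Give the layer the mixed_float16 policy, as a global policy would.
layer = TwoInputLayer(dtype="mixed_float16")

# Both inputs start out as float32.
query = tf.ones((2, 4), dtype=tf.float32)
candidates = tf.ones((3, 4), dtype=tf.float32)

seen_query, seen_candidates = layer(query, candidates)

# The first argument is cast to the layer's compute dtype (float16).
print("query dtype inside call:", seen_query.dtype)
# Whether the second argument is also cast depends on the Keras version.
print("candidate dtype inside call:", seen_candidates.dtype)
```

If the two printed dtypes differ, a subsequent `tf.linalg.matmul` on them would fail, which is why the workaround casts both embeddings to float32 and keeps the task's own policy at float32.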

So essentially it's best to follow @thushv89's advice for now.

@maciejkula, if desired, I'm happy to contribute a PR to expose the layer dtype for the task and loss layers.

patrickorlando avatar Feb 01 '23 05:02 patrickorlando

Same problem here!

danielmejiaMCO avatar Mar 23 '23 03:03 danielmejiaMCO

+1

sergeygeo avatar Mar 19 '24 08:03 sergeygeo