
Keras 3 preference: should I build my distributed-training CTR model on JAX or TensorFlow?

MoFHeka opened this issue on Jun 17 '24 · 1 comment

I am a developer of TensorFlow Recommenders-Addons, and I now need to build an all-to-all embedding layer for multi-GPU distributed training of recommendation models. The older TensorFlow distribution strategy API clearly doesn't meet this need. So the question is: should I develop on TF DTensor or on JAX? Keras's support for TF DTensor doesn't seem very friendly. On the other hand, JAX lacks support for online inference serving and for the functional components that various recommendation algorithms rely on. Recommenders-Addons also has a lot of custom operators.

MoFHeka · Jun 17 '24 19:06
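One common way to build the all-to-all embedding layer described above is to shard the table row-wise across workers and use two all-to-all exchanges: first send each id to the worker that owns its row, then send the looked-up embeddings back. The sketch below is a single-process NumPy simulation of that forward pass, not the Recommenders-Addons implementation. It assumes a hypothetical 4-worker setup, mod sharding (id `i` lives on worker `i % NUM_WORKERS`), and a plain-Python `all_to_all` helper standing in for the real GPU collective.

```python
import numpy as np

NUM_WORKERS = 4  # hypothetical number of GPUs
VOCAB = 20       # hypothetical vocabulary size
DIM = 3          # hypothetical embedding dimension

rng = np.random.default_rng(0)
full_table = rng.normal(size=(VOCAB, DIM)).astype(np.float32)

# Mod sharding: id i is stored on worker i % NUM_WORKERS at local row i // NUM_WORKERS.
shards = [full_table[w::NUM_WORKERS] for w in range(NUM_WORKERS)]


def all_to_all(send):
    """Stand-in for the collective: send[src][dst] becomes recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


def sharded_lookup(ids_per_worker):
    n = NUM_WORKERS
    # 1. Each worker buckets its local ids by the worker that owns them.
    id_buckets = [[ids[ids % n == dst] for dst in range(n)] for ids in ids_per_worker]
    # 2. First all-to-all: every owner receives the ids it must serve.
    recv_ids = all_to_all(id_buckets)
    # 3. Each owner gathers rows from its local shard.
    emb_buckets = [
        [shards[owner][ids // n] for ids in recv_ids[owner]] for owner in range(n)
    ]
    # 4. Second all-to-all: embeddings travel back to the requesting workers.
    returned = all_to_all(emb_buckets)
    # 5. Each worker scatters the returned rows back into its original id order.
    outputs = []
    for w, ids in enumerate(ids_per_worker):
        out = np.empty((len(ids), DIM), dtype=full_table.dtype)
        for owner in range(n):
            out[ids % n == owner] = returned[w][owner]
        outputs.append(out)
    return outputs


# Usage: each simulated worker looks up a small batch of random ids.
ids_per_worker = [rng.integers(0, VOCAB, size=5) for _ in range(NUM_WORKERS)]
outs = sharded_lookup(ids_per_worker)
```

Each worker only ever stores `VOCAB / NUM_WORKERS` rows, which is the point of the design. In a real multi-GPU layer, the backward pass mirrors these steps, routing gradients back to the owning shard with the same pair of exchanges.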

If you need an SPMD API, I 100% recommend JAX. It's more mature and better optimized than DTensor. You can use it via the easy keras.distribution API with the JAX backend. https://keras.io/guides/distribution/

fchollet · Jun 18 '24 00:06
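With the JAX backend, the keras.distribution API described above maps its device mesh and layouts onto JAX's own sharding primitives (`jax.sharding.Mesh`, `NamedSharding`). The rough sketch below uses those primitives directly instead of the Keras API to illustrate the SPMD style. It row-shards an embedding table across whatever devices are available, shards a batch of ids along the same mesh axis, and lets `jax.jit` partition the gather. The sizes, the mesh axis name `"shard"`, and the id pattern are arbitrary choices for illustration. The compiler decides which collectives to insert, so the gather is not guaranteed to lower to an all-to-all. Explicit control over the exchange would need a manual approach such as `jax.lax.all_to_all` inside `shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("shard",))
n = len(devices)

# Illustrative sizes, chosen to divide evenly across the mesh.
vocab_size, dim, batch = 16 * n, 4, 8 * n

# Row-shard the embedding table: each device holds vocab_size / n rows.
table = jax.device_put(
    jnp.arange(vocab_size * dim, dtype=jnp.float32).reshape(vocab_size, dim),
    NamedSharding(mesh, P("shard", None)),
)

# Shard the batch of ids along the same mesh axis (data parallelism).
ids = jax.device_put(
    (jnp.arange(batch, dtype=jnp.int32) * 3) % vocab_size,
    NamedSharding(mesh, P("shard")),
)


@jax.jit
def lookup(table, ids):
    # The compiler partitions this gather and inserts any cross-device communication.
    return jnp.take(table, ids, axis=0)


out = lookup(table, ids)
```

To try it with several simulated devices on a CPU-only machine, set `XLA_FLAGS=--xla_force_host_platform_device_count=8` in the environment before importing `jax`.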

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] · Jan 11 '25 02:01

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

github-actions[bot] · Jan 26 '25 02:01

Are you satisfied with the resolution of your issue?

google-ml-butler[bot] · Jan 26 '25 02:01