
Add TPU support

Open LeoGrin opened this issue 8 months ago • 1 comment

Describe the workflow you want to enable

On platforms like Kaggle, the available GPU hardware (e.g., P100, T4) is often dated and heavily utilized. In contrast, powerful TPUs are frequently underused. Enabling TabPFN to run on TPUs would offer a compelling and high-performance alternative for users, particularly within the Kaggle community.

Describe your proposed solution

Allow TPUs to be used as a device, instead of just CPUs and GPUs. TPUs should also be detected when device is "auto". Ideally, I think installing torch_xla should be optional.

Unfortunately, while this enables TPU usage, the actual utilization is very low. More optimization is needed.
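
As a rough illustration of the device handling proposed above, here is a minimal sketch of how "auto" could fall back across backends while keeping torch_xla optional. This is not TabPFN's actual code: `resolve_device`, `xla_available`, `cuda_available`, the "tpu" alias, and the CUDA > XLA > CPU priority are all assumptions made for the example. It also only checks whether the torch_xla package is installed, which is a proxy and not proof that a TPU is attached.

```python
import importlib.util
from typing import Optional


def xla_available() -> bool:
    """True if the optional torch_xla package is installed (it is not imported here)."""
    return importlib.util.find_spec("torch_xla") is not None


def cuda_available() -> bool:
    """True if torch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch  # imported lazily so this sketch also runs without torch

    return torch.cuda.is_available()


def resolve_device(
    requested: str = "auto",
    *,
    has_cuda: Optional[bool] = None,  # override for testing; None means "probe"
    has_xla: Optional[bool] = None,   # override for testing; None means "probe"
) -> str:
    """Map a user-facing device string to a backend name (hypothetical helper)."""
    requested = requested.lower()
    if requested == "tpu":  # assumed alias, not an official torch device name
        requested = "xla"

    if requested == "cpu":
        return "cpu"
    if requested == "cuda":
        return "cuda"  # passed through; torch would raise later if unavailable

    if requested == "xla":
        if not (xla_available() if has_xla is None else has_xla):
            raise ImportError(
                "TPU support needs the optional torch_xla package, "
                "which does not appear to be installed."
            )
        return "xla"

    if requested != "auto":
        raise ValueError(f"Unknown device: {requested!r}")

    # "auto": assumed priority is CUDA first, then XLA, then CPU.
    if cuda_available() if has_cuda is None else has_cuda:
        return "cuda"
    if xla_available() if has_xla is None else has_xla:
        return "xla"
    return "cpu"
```

With something like this, `resolve_device("auto")` on a machine with torch_xla installed and no CUDA device would return "xla", and an explicit "tpu" request without torch_xla would fail with a clear ImportError instead of making torch_xla a hard dependency.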

Describe alternatives you've considered, if relevant

No response

Additional context

No response

Impact

None

LeoGrin, Apr 07 '25 11:04

A PR by @MagnusBuehler looked into this: https://github.com/PriorLabs/TabPFN/pull/424. However, the speed improvements seem marginal compared to CPU, and more work would be needed to make our architecture efficient on TPU.

noahho, Sep 13 '25 13:09