
Print warning for larger datasets when run on CPU

LeoGrin opened this issue 10 months ago • 4 comments

Describe the workflow you want to enable

Creating this new issue to be able to close #100

Describe your proposed solution

Print a warning for larger datasets when run on CPU (>1000 samples? Or rather when the estimated runtime is very long, e.g. >10 min?). The warning should say that the user can try our API client instead or move to a GPU.

Describe alternatives you've considered, if relevant

No response

Additional context

No response

Impact

None

LeoGrin · Feb 17 '25 16:02

I think we could add the code below to regressor.py, in the fit method, in order to warn users on CPU when the dataset is large:

import warnings

# Warn when fitting on CPU with a large dataset
if self.device == 'cpu' and X.shape[0] > 10000:
    warnings.warn("Running on CPU with a large dataset may be slow. Consider using a GPU.")

@noahho @LeoGrin please tell me whether I am thinking in the right direction.

Krishnadubey1008 · Mar 18 '25 07:03

Yes exactly, I'd just put this at 1000 samples and also include a reference to using the API if a GPU is not available: https://github.com/PriorLabs/tabpfn-client
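A minimal sketch of what that could look like, assuming fit has access to self.device (as a string or torch.device) and to X as an array-like, as in the snippet above; the helper name is illustrative, not an existing function:

import warnings

CPU_SAMPLE_WARNING_THRESHOLD = 1000  # threshold suggested above

def _warn_if_large_on_cpu(device, n_samples):
    # Warn when fitting on CPU with more samples than the threshold
    if str(device) == "cpu" and n_samples > CPU_SAMPLE_WARNING_THRESHOLD:
        warnings.warn(
            "Running TabPFN on CPU with more than 1000 samples may be very slow. "
            "Consider using a GPU, or try the TabPFN API client instead: "
            "https://github.com/PriorLabs/tabpfn-client",
            UserWarning,
            stacklevel=2,
        )

In fit this could then be called as _warn_if_large_on_cpu(self.device, X.shape[0]) before the heavy computation starts.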

Longer term, I'd make a larger change where we estimate the time taken, similar to: https://github.com/PriorLabs/TabPFN/blob/main/src/tabpfn/model/memory.py#L107
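A rough sketch of what a runtime-based check might look like; estimate_fit_seconds is a hypothetical cost model with placeholder constants (not an existing function in memory.py), and the scaling and constants would need real benchmarking:

import warnings

RUNTIME_WARNING_THRESHOLD_S = 600  # warn if the estimated fit time exceeds ~10 minutes

def estimate_fit_seconds(n_samples, n_features, device):
    # Placeholder cost model: attention over samples scales roughly quadratically
    per_cell_seconds = 5e-7 if str(device) == "cpu" else 5e-9
    return per_cell_seconds * (n_samples ** 2) * max(n_features, 1)

def _warn_if_estimated_runtime_long(n_samples, n_features, device):
    estimate = estimate_fit_seconds(n_samples, n_features, device)
    if estimate > RUNTIME_WARNING_THRESHOLD_S:
        warnings.warn(
            f"Estimated fit time is roughly {estimate / 60:.0f} minutes on {device}. "
            "Consider a GPU or the TabPFN API client: https://github.com/PriorLabs/tabpfn-client",
            UserWarning,
        )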

noahho · Mar 18 '25 09:03

@noahho I have raised a PR applying the suggested changes (the 1000-sample threshold plus a reference to the API client at https://github.com/PriorLabs/tabpfn-client); please review it. I will also try to implement the longer-term solution.

Krishnadubey1008 · Mar 18 '25 10:03

Hi, thanks for opening this issue! I'm interested in working on this feature. I have some experience with Python and optimizing ML workflows, so I’m excited to contribute.

Before I get started, I have a couple of questions to ensure I fully understand the requirements:

For the warning threshold, would you prefer to trigger it based on a fixed sample count (e.g., >1000 samples), an estimated runtime (e.g., >10 minutes), or perhaps a combination of both? Should the warning be printed as terminal output (e.g., using Python's logging module) or integrated into the API's response?
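For reference, the two output options mentioned above would look roughly like this (purely illustrative, not the project's current behaviour):

import logging
import warnings

# Option 1: a Python warning, which users can filter or turn into an error
warnings.warn("Large dataset on CPU; fitting may be slow.", UserWarning)

# Option 2: a log record via the logging module, controlled by logging configuration
logging.getLogger("tabpfn").warning("Large dataset on CPU; fitting may be slow.")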

Looking forward to your guidance so I can implement this in a way that best fits the project's workflow. Thanks!

Best regards, Rishi

RishiP2006 · Mar 23 '25 11:03