TabPFN
Support for 'mps' device (i.e. for Apple Silicon)
It would be great if the model supported Apple Silicon devices, so that the 'mps' device could be used as follows:
from tabpfn import TabPFNClassifier
model = TabPFNClassifier(device='mps', N_ensemble_configurations=16)
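A minimal sketch for choosing the device at runtime rather than hard-coding 'mps'. It assumes PyTorch 1.12 or later, where `torch.backends.mps.is_available()` exists, and falls back to 'cpu' when PyTorch is missing or MPS is unavailable. The helper name `pick_device` is hypothetical and is not part of TabPFN.

```python
def pick_device() -> str:
    """Return 'mps' when PyTorch reports a usable MPS backend, else 'cpu'.

    Hypothetical helper, not part of TabPFN.
    """
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no MPS backend to query.
        return "cpu"
    # torch.backends.mps only exists in PyTorch >= 1.12; getattr avoids
    # an AttributeError on older versions.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(device)
```

The result could then be passed as `device=pick_device()` to the `TabPFNClassifier` constructor, so the same script also runs on machines without Apple Silicon.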
Sorry that we never replied to this! We'll keep it in mind, but unfortunately we have a lot to do. If you'd like to create a PR for this, it would be welcome :)
I get an error saying that the MPS device does not support 64-bit floats.
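The MPS backend cannot hold float64 tensors, so any float64 data sent to 'mps' fails. One possible workaround, sketched here under the assumption that the float64 values come from NumPy inputs (NumPy's default float dtype is float64), is to cast the features to float32 before fitting. This may not be enough if the library itself creates float64 tensors internally; that case would need changes in the library code. The helper name `to_float32` is hypothetical.

```python
import numpy as np


def to_float32(X):
    """Cast floating-point array-like input to float32 for use on MPS.

    Hypothetical helper: MPS tensors cannot be float64, so float64
    arrays are downcast. Non-float arrays (e.g. integer labels) are
    returned unchanged.
    """
    X = np.asarray(X)
    if np.issubdtype(X.dtype, np.floating) and X.dtype != np.float32:
        # Downcasting keeps roughly 7 significant decimal digits
        # instead of roughly 16, which is usually fine for tabular features.
        return X.astype(np.float32)
    return X


X = np.random.default_rng(0).random((5, 3))  # float64 by default
X32 = to_float32(X)
print(X.dtype, X32.dtype)
```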