Create Sklearn Interface for FineTuning TabPFN
Describe the workflow you want to enable
The current fine-tuning process is implemented as a standalone script. While this demonstrates the functionality well, it is not easily integrated into standard machine learning pipelines that rely on the scikit-learn API (e.g., for cross-validation, hyperparameter tuning, or inclusion in a sklearn.pipeline.Pipeline). This limits the reusability and interoperability of the fine-tuning workflow.
Describe your proposed solution
To improve usability and integration, the fine-tuning logic should be encapsulated in a scikit-learn compatible classifier. It would expose the standard .fit(), .predict(), and .predict_proba() interface. A fine-tuned model could then be used with cross-validation, hyperparameter search, and sklearn.pipeline.Pipeline like any other estimator.
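The sketch below makes the scikit-learn conventions behind this interface concrete. It uses a hypothetical `ContractDemo` skeleton, which is not part of TabPFN and learns nothing, with the proposed `epochs`/`learning_rate` signature. It shows two things scikit-learn relies on. Constructor parameters are stored unchanged, so `get_params`, `set_params` and `clone` can round-trip them. Learned state is created only in `fit`, so `check_is_fitted` can tell a fitted estimator from an unfitted one.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class ContractDemo(ClassifierMixin, BaseEstimator):
    """Hypothetical skeleton with the proposed constructor; it learns nothing."""

    def __init__(self, epochs: int = 10, learning_rate: float = 1e-5):
        # Store constructor arguments unchanged (no validation, no copying)
        # so that get_params(), set_params() and clone() can round-trip them.
        self.epochs = epochs
        self.learning_rate = learning_rate

    def fit(self, X, y):
        # Learned state lives in attributes ending with "_", created only in fit().
        self.classes_ = np.unique(y)
        return self


est = ContractDemo(epochs=3)
print(est.get_params())  # {'epochs': 3, 'learning_rate': 1e-05}

# clone() builds a fresh, unfitted copy from get_params(); cross-validation
# and grid search do this before fitting each fold or candidate.
fresh = clone(est.set_params(learning_rate=1e-4))
print(fresh.get_params())  # {'epochs': 3, 'learning_rate': 0.0001}

try:
    check_is_fitted(fresh)  # no attribute ending in "_" exists yet
except NotFittedError:
    print("not fitted yet")

check_is_fitted(fresh.fit([[0], [1]], [0, 1]))  # passes once classes_ exists
```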
Here is a proposed structure for the new class:
import numpy as np
import sklearn.base
from sklearn.utils.validation import check_is_fitted
from tabpfn import TabPFNClassifier


class FinetunedTabPFNClassifier(sklearn.base.ClassifierMixin, sklearn.base.BaseEstimator):
    """
    A scikit-learn compatible wrapper for a fine-tuned TabPFNClassifier.
    """

    def __init__(self, epochs: int = 10, learning_rate: float = 1e-5):
        # Hyperparameters go in the constructor and are stored unchanged
        self.epochs = epochs
        self.learning_rate = learning_rate

    def fit(self, X, y):
        """
        Initializes a TabPFNClassifier and fine-tunes it on the data (X, y).
        """
        # 1. Initialize the base estimator inside fit
        self.base_estimator_ = TabPFNClassifier(device='cuda', n_estimators=2)

        # 2. The fine-tuning logic from the script would go here.
        # It would use self.epochs, self.learning_rate, and the data X, y.
        # ...

        # 3. Assumption: TabPFN predicts in-context, so the (fine-tuned) model
        #    still needs fit(X, y) to store the training data it conditions on.
        self.base_estimator_.fit(X, y)

        # 4. Learned attributes end with an underscore
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        # Use the fitted estimator for prediction
        check_is_fitted(self)
        return self.base_estimator_.predict(X)

    def predict_proba(self, X):
        # Use the fitted estimator for probability prediction
        check_is_fitted(self)
        return self.base_estimator_.predict_proba(X)
Describe alternatives you've considered, if relevant
One alternative is to keep the current script-based approach and provide extensive documentation on how to adapt it. However, this places a larger burden on the end-user and prevents seamless integration with other tools.
Another option is to provide a set of helper functions. Users would still have to wire these into cross-validation or pipelines by hand, whereas a dedicated estimator class offers a cleaner, more intuitive, and standardized API.
Impact
None