TabPFN

Create Sklearn Interface for FineTuning TabPFN

Open · klemens-floege opened this issue 5 months ago · 0 comments

Describe the workflow you want to enable

The current fine-tuning process is implemented as a standalone script. While this demonstrates the functionality well, it is not easily integrated into standard machine learning pipelines that rely on the scikit-learn API (e.g., for cross-validation, hyperparameter tuning, or inclusion in a sklearn.pipeline.Pipeline). This limits the reusability and interoperability of the fine-tuning workflow.
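For context, the following is a minimal sketch of the kind of integration a standalone script cannot offer. It uses a stock scikit-learn classifier, with LogisticRegression chosen only as a placeholder for a fine-tuned TabPFN model. The model is wrapped in a Pipeline with a preprocessing step and scored with cross-validation in a single call.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# Any object following the estimator API can be a Pipeline step.
# LogisticRegression is a placeholder for the fine-tuned TabPFN model.
pipe = Pipeline([("scale", StandardScaler()), ("clf", LogisticRegression())])

# cross_val_score clones the pipeline and refits it on each of the 5 folds.
scores = cross_val_score(pipe, X, y, cv=5)
print(scores.mean())
```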

Describe your proposed solution

To improve usability and integration, the fine-tuning logic should be encapsulated in a scikit-learn compatible classifier. This would provide the standard .fit(), .predict(), and .predict_proba() interface, making fine-tuning the model much more user-friendly.

Here is a proposed structure for the new class:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted
from tabpfn import TabPFNClassifier


class FinetunedTabPFNClassifier(ClassifierMixin, BaseEstimator):
    """
    A scikit-learn compatible wrapper for a fine-tuned TabPFNClassifier.
    """

    def __init__(self, epochs: int = 10, learning_rate: float = 1e-5):
        # Hyperparameters go in the constructor and are stored unchanged,
        # so that get_params(), set_params() and clone() work.
        self.epochs = epochs
        self.learning_rate = learning_rate

    def fit(self, X, y):
        """
        Initializes a TabPFNClassifier and fine-tunes it on the data (X, y).
        """
        # 1. Initialize the base estimator inside fit
        self.base_estimator_ = TabPFNClassifier(device="cuda", n_estimators=2)

        # 2. The fine-tuning logic from the script would go here.
        #    It would use self.epochs, self.learning_rate, and the data X, y.
        #    ...

        # 3. Fit the (fine-tuned) estimator so that it holds (X, y) as its
        #    in-context training data, which TabPFN needs before predicting.
        self.base_estimator_.fit(X, y)

        # 4. Store fitted attributes; names ending in "_" mark the
        #    estimator as fitted for check_is_fitted().
        self.classes_ = np.unique(y)

        return self

    def predict(self, X):
        # Use the fitted estimator for prediction
        check_is_fitted(self)
        return self.base_estimator_.predict(X)

    def predict_proba(self, X):
        # Use the fitted estimator for probability prediction
        check_is_fitted(self)
        return self.base_estimator_.predict_proba(X)
```
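TabPFN needs downloaded model weights and ideally a GPU, so the structure above can be checked against scikit-learn tooling with a stand-in instead. The sketch below is an assumption for illustration only. LogisticRegression replaces TabPFNClassifier, the fine-tuning step is a no-op, and the class name StandInFinetunedClassifier is hypothetical. It shows that, because __init__ only stores its arguments, GridSearchCV can clone the wrapper and tune epochs and learning_rate through a Pipeline.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted


class StandInFinetunedClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in with the same structure as the proposed wrapper.

    LogisticRegression replaces TabPFNClassifier so that this runs anywhere.
    """

    def __init__(self, epochs: int = 10, learning_rate: float = 1e-5):
        # Store hyperparameters unchanged so that clone() can rebuild the object.
        self.epochs = epochs
        self.learning_rate = learning_rate

    def fit(self, X, y):
        self.base_estimator_ = LogisticRegression()
        # Fine-tuning would use self.epochs and self.learning_rate here;
        # in this stand-in they are stored but not used.
        self.base_estimator_.fit(X, y)
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        check_is_fitted(self)
        return self.base_estimator_.predict(X)

    def predict_proba(self, X):
        check_is_fitted(self)
        return self.base_estimator_.predict_proba(X)


X, y = load_iris(return_X_y=True)

# make_pipeline names each step after its lowercased class name.
pipe = make_pipeline(StandardScaler(), StandInFinetunedClassifier())
search = GridSearchCV(
    pipe,
    param_grid={
        "standinfinetunedclassifier__epochs": [5, 10],
        "standinfinetunedclassifier__learning_rate": [1e-5, 1e-4],
    },
    cv=3,
)
search.fit(X, y)
print(search.best_params_)
print(search.predict_proba(X[:2]).shape)  # (2, 3)
```

In this stand-in every grid point scores the same, because the hyperparameters are ignored. With real fine-tuning, they would change the fitted model. GridSearchCV clones the estimator for every candidate, so any fine-tuned state has to be created inside fit (as base_estimator_), never in __init__.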

Describe alternatives you've considered, if relevant

One alternative is to keep the current script-based approach and provide extensive documentation on how to adapt it. However, this places a larger burden on the end-user and prevents seamless integration with other tools.

Another option is to provide a set of helper functions, but a dedicated class offers a much cleaner, more intuitive, and standardized API.

Additional context

No response

Impact

None

klemens-floege · Jul 22 '25 16:07