
Regression return logits

Open noahho opened this issue 5 months ago • 1 comment

Motivation and Context

This change introduces two improvements. First, it addresses the request in Issue #374 for raw logits from TabPFNClassifier. Explanation methods such as Shapley values are often more numerically stable and easier to interpret on an additive scale like logits than on probabilities. This PR adds a predict_logits method to support these use cases.

Second, this work includes a significant refactoring of the internal forward and prediction methods for both TabPFNClassifier and TabPFNRegressor. This simplifies the forward pass interface, especially for the regressor, and provides a key benefit: a substantial reduction in memory usage for TabPFNRegressor. By processing ensemble outputs sequentially instead of stacking them, the regressor can now handle larger datasets and a higher number of estimators more efficiently.
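
The memory argument above can be illustrated with a minimal sketch. It is not the TabPFN implementation: the function names and the list-based "outputs" are assumptions made for illustration only. It contrasts stacking every ensemble member's output before averaging with folding each output into a running sum as it arrives.

```python
from typing import Iterable, Iterator, List


def stacked_mean(outputs: Iterable[List[float]]) -> List[float]:
    """Baseline: hold every member's output in memory, then average."""
    stacked = list(outputs)  # memory grows with the number of ensemble members
    n = len(stacked)
    return [sum(column) / n for column in zip(*stacked)]


def sequential_mean(outputs: Iterable[List[float]]) -> List[float]:
    """Fold each member's output into a running sum; only one output is live at a time."""
    running = None
    count = 0
    for out in outputs:
        if running is None:
            running = list(out)
        else:
            running = [r + o for r, o in zip(running, out)]
        count += 1
    if running is None:
        raise ValueError("no ensemble outputs were provided")
    return [r / count for r in running]


def fake_ensemble(n_members: int, n_values: int) -> Iterator[List[float]]:
    """Hypothetical stand-in for per-estimator outputs, produced lazily."""
    for m in range(n_members):
        yield [float(m + j) for j in range(n_values)]


if __name__ == "__main__":
    print(sequential_mean(fake_ensemble(4, 3)))
    print(stacked_mean(fake_ensemble(4, 3)))
```

Both functions return the same average. The sequential version keeps one accumulator plus the current member's output, so its peak memory does not depend on the number of estimators.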


Public API Changes

  • [ ] No Public API changes
  • [x] Yes, Public API changes (Details below)

Details of Public API Changes:

  • TabPFNClassifier.predict_logits(X: XType) -> np.ndarray: A new public method that returns the raw, unnormalized logits for the input samples X.

  • TabPFNRegressor.forward(...) Signature Change (Breaking): The forward method for the regressor has been simplified. It now returns only a single tensor of logits, instead of the previous tuple (averaged_logits, outputs, borders). This is a breaking change for users who call forward() directly (e.g., in finetuning scripts). The new interface is simpler and more memory-efficient.

  • Internal Refactoring: The predict and predict_proba methods in both classes have been refactored to use the new internal logic. Their public signatures remain unchanged, ensuring backward compatibility for standard prediction workflows.
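
For scripts that call forward() directly and must run against both the old and the new return shape, a small compatibility shim is one option. The helper below is hypothetical and not part of TabPFN. It relies only on the description above: the old return value was the tuple (averaged_logits, outputs, borders), and the new one is a single tensor of logits.

```python
from typing import Any


def extract_logits(forward_result: Any) -> Any:
    """Return the logits from either the old or the new forward() result.

    Old shape (per the PR description): (averaged_logits, outputs, borders).
    New shape: a single tensor of logits, returned unchanged.
    This helper is a hypothetical illustration, not a TabPFN API.
    """
    if isinstance(forward_result, tuple):
        averaged_logits, _outputs, _borders = forward_result
        return averaged_logits
    return forward_result
```

Code that relied on the outputs or borders elements of the old tuple needs another source for them; this shim only covers the logits.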


How Has This Been Tested?

This PR includes comprehensive new and updated test cases to ensure the correctness and consistency of the new functionality and refactored code paths:

  • predict_logits Consistency: A new test, test_predict_logits_and_consistency, has been added to tests/test_classifier_interface.py. This parametrized test verifies that predict_logits returns output of the correct shape and type. Crucially, it asserts that applying softmax to the raw logits yields results numerically close to the probabilities from predict_proba across the tested configurations.

  • Regressor Forward Pass: A new test, test_forward_predict_logit_consistency, was added to tests/test_regressor_interface.py to validate that the new, memory-efficient forward pass produces logits identical to those from the high-level predict method, ensuring correctness after the refactoring.

  • Behavioral Tests: New tests like test_softmax_temperature_impact_on_logits_magnitude and test_balance_probabilities_alters_proba_output confirm that the modular post-processing pipeline in the classifier behaves as expected.

  • Finetuning: Finetuning tests for the regressor (e.g., in tests/test_finetuning_regressor.py and examples/finetune_regressor.py) have been updated to align with the new, simplified forward method signature.
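
The consistency property checked by the classifier test can be sketched without TabPFN. This is an illustration of the idea, not the test in the PR: the logits and probabilities below are made-up stand-ins for the outputs of predict_logits and predict_proba, and softmax and rows_close are helper names chosen here.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rows_close(a: List[List[float]], b: List[List[float]], tol: float = 1e-6) -> bool:
    """True if two equally shaped 2D lists agree element-wise within tol."""
    if len(a) != len(b):
        return False
    return all(
        len(ra) == len(rb)
        and all(math.isclose(x, y, rel_tol=0.0, abs_tol=tol) for x, y in zip(ra, rb))
        for ra, rb in zip(a, b)
    )


if __name__ == "__main__":
    # Stand-in values; in a real check these would come from the classifier.
    logits = [[0.0, 0.0], [2.0, 0.0, -2.0]]
    proba = [softmax(row) for row in logits]
    assert rows_close([softmax(row) for row in logits], proba)
    assert all(math.isclose(sum(row), 1.0) for row in proba)
```

Whether a plain softmax reproduces predict_proba exactly may depend on post-processing options such as probability balancing, so a real check should use a tolerance and the same configuration for both calls.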


Checklist

  • [x] The changes have been tested locally.
  • [ ] Documentation has been updated (if the public API or usage changes).
  • [x] An entry has been added to CHANGELOG.md (if relevant for users).
  • [x] The code follows the project's style guidelines.
  • [x] I have considered the impact of these changes on the public API.

noahho, Jul 15 '25 16:07

@gemini-code-assist review

noahho, Jul 16 '25 08:07