Regression return logits
## Motivation and Context
This change introduces two major improvements. First, it directly responds to the request in Issue #374 for raw logits from `TabPFNClassifier`. Explaining model outputs (e.g., with Shapley values) is more numerically stable and often more intuitive on an additive scale like logits. This PR adds a `predict_logits` method to facilitate these use cases.

Second, this work includes a significant refactoring of the internal forward and prediction methods for both `TabPFNClassifier` and `TabPFNRegressor`. This simplifies the forward pass interface, especially for the regressor, and provides a key benefit: a substantial reduction in memory usage for `TabPFNRegressor`. By processing ensemble outputs sequentially instead of stacking them, the regressor can now handle larger datasets and a higher number of estimators more efficiently.
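The memory saving from sequential processing can be illustrated with a small sketch. The function names below are illustrative, not the actual TabPFN internals: the point is that stacking materializes all ensemble member outputs at once, while a running sum only ever holds one extra buffer.

```python
import numpy as np

def averaged_logits_stacked(member_outputs):
    # Old-style approach: materialize every ensemble output at once.
    # Peak memory ~ n_estimators * (n_samples * n_outputs).
    stacked = np.stack(list(member_outputs))  # (n_estimators, n_samples, n_outputs)
    return stacked.mean(axis=0)

def averaged_logits_sequential(member_outputs):
    # New-style approach: fold each member into a running sum as it arrives.
    # Peak memory ~ one (n_samples, n_outputs) buffer, regardless of n_estimators.
    total, count = None, 0
    for out in member_outputs:
        total = out.copy() if total is None else total + out
        count += 1
    return total / count
```

Both variants produce the same averaged logits; only the peak memory footprint differs, which is what allows more estimators on larger datasets.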
## Public API Changes
- [ ] No Public API changes
- [x] Yes, Public API changes (Details below)
Details of Public API Changes:
- `TabPFNClassifier.predict_logits(X: XType) -> np.ndarray`: A new public method that returns the raw, unnormalized logits for the input samples `X`.
- `TabPFNRegressor.forward(...)` signature change (breaking): The `forward` method for the regressor has been simplified. It now returns a single tensor of logits instead of the previous tuple `(averaged_logits, outputs, borders)`. This is a breaking change for users who call `forward()` directly (e.g., in finetuning scripts). The new interface is simpler and more memory-efficient.
- Internal refactoring: The `predict` and `predict_proba` methods in both classes have been refactored to use the new internal logic. Their public signatures remain unchanged, ensuring backward compatibility for standard prediction workflows.
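The relationship between the new logits method and the existing probabilities is simply a softmax. A minimal sketch (the `softmax` helper below is written out for illustration; the commented usage assumes a fitted `TabPFNClassifier` named `clf`):

```python
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax: subtract the row max before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical usage against the new API:
#   logits = clf.predict_logits(X_test)   # shape (n_samples, n_classes)
#   proba  = clf.predict_proba(X_test)
#   np.testing.assert_allclose(softmax(logits), proba, atol=1e-5)
```

Because softmax is shift-invariant per row, explanation methods that work on the additive logit scale (e.g., Shapley values) avoid the saturation effects of the probability scale.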
## How Has This Been Tested?
This PR includes comprehensive new and updated test cases to ensure the correctness and consistency of the new functionality and refactored code paths:
- `predict_logits` consistency: A new test, `test_predict_logits_and_consistency`, has been added to `tests/test_classifier_interface.py`. This highly parametrized test verifies that `predict_logits` returns output of the correct shape and type. Crucially, it asserts that applying softmax to the raw logits yields results numerically close to the probabilities from `predict_proba`, ensuring consistency across various configurations.
- Regressor forward pass: A new test, `test_forward_predict_logit_consistency`, was added to `tests/test_regressor_interface.py` to validate that the new, memory-efficient `forward` pass produces logits identical to the high-level `predict` method, ensuring correctness after the refactoring.
- Behavioral tests: New tests such as `test_softmax_temperature_impact_on_logits_magnitude` and `test_balance_probabilities_alters_proba_output` confirm that the modular post-processing pipeline in the classifier behaves as expected.
- Finetuning: Finetuning tests for the regressor (e.g., in `tests/test_finetuning_regressor.py` and `examples/finetune_regressor.py`) have been updated to align with the new, simplified `forward` method signature.
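The core assertion pattern of the consistency tests can be sketched without a trained model. `StubClassifier` below is a hypothetical stand-in for a fitted `TabPFNClassifier`, used only to show the shape of the check:

```python
import numpy as np

class StubClassifier:
    # Hypothetical stand-in for a fitted TabPFNClassifier, for illustration only.
    def __init__(self, logits):
        self._logits = np.asarray(logits, dtype=float)

    def predict_logits(self, X):
        return self._logits

    def predict_proba(self, X):
        shifted = self._logits - self._logits.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=1, keepdims=True)

def check_logit_proba_consistency(clf, X, atol=1e-6):
    # The essence of the consistency test: softmax(predict_logits) ~ predict_proba.
    logits = clf.predict_logits(X)
    proba = clf.predict_proba(X)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    np.testing.assert_allclose(exp / exp.sum(axis=1, keepdims=True), proba, atol=atol)
```

In the real test suite this check is parametrized over configurations (device, ensemble size, preprocessing), but the assertion itself stays the same.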
## Checklist
- [x] The changes have been tested locally.
- [ ] Documentation has been updated (if the public API or usage changes).
- [x] An entry has been added to `CHANGELOG.md` (if relevant for users).
- [x] The code follows the project's style guidelines.
- [x] I have considered the impact of these changes on the public API.
@gemini-code-assist review