Fix output tensor dimension in degenerate `optimize_acqf` call
Motivation
Fixes #2740
@sdaulton: This is a draft for fixing the dimensionality issue. While the PR technically solves the problem, I think it only addresses the symptom and not the underlying root cause. For the latter, I'd need some guidance from your end.
Specifically, something seems fishy about the tensor dimensions promised by the acquisition function base class. The docstring states that the input is a (b) x q x d-dim tensor and the output is a (b)-dim tensor. As I understand it, b is an optional batch dimension. Read literally, the contract then implies:
- If you pass a b x q x d-dim tensor, you get a b-dim tensor back.
- If you pass a q x d-dim tensor, you get a dimensionless tensor (i.e., a scalar value) back.
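The two rules above boil down to "drop the trailing q x d dimensions". A minimal sketch in plain Python, with shapes as tuples instead of torch tensors; `expected_acqf_output_shape` is a hypothetical helper, not part of BoTorch, and treating several leading dimensions as batch dimensions is an assumption beyond the "(b)" in the docstring:

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_acqf_output_shape(x_shape: Shape) -> Shape:
    """Output shape implied by the docstring contract "(b) x q x d -> (b)".

    The trailing two dims (q candidates, d features) are reduced away and
    whatever precedes them is kept as the (optional) batch shape. Allowing
    more than one leading batch dim is an assumption of this sketch.
    """
    if len(x_shape) < 2:
        raise ValueError(f"expected at least a q x d input, got shape {x_shape}")
    return tuple(x_shape[:-2])


# b x q x d -> b: one acquisition value per batch element
print(expected_acqf_output_shape((4, 3, 2)))  # (4,)
# q x d -> (): a single scalar, i.e. a 0-dim tensor
print(expected_acqf_output_shape((3, 2)))  # ()
```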
Now the problem is that the second part actually seems to be violated, which causes the problem in #2740. Interestingly, when writing the test, I wanted to use a `MockAcquisitionFunction()` object, as other tests do. However, that object does seem to fulfill the contract, so I couldn't reproduce the problem with the mock and had to reuse the code from my original minimal example.
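To make the suspected violation concrete, here is a hypothetical shape-level sketch (plain Python, no torch). `always_batched_output_shape` is an invented stand-in, not BoTorch code: it models an implementation that promotes an unbatched q x d input to 1 x q x d and never squeezes the batch dimension back out, so it returns shape (1,) where the contract promises (). A mock that reduces only over q and d would pass the same check, which is consistent with the mock not reproducing the issue.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def contract_output_shape(x_shape: Shape) -> Shape:
    # Docstring contract: drop the trailing q x d dims.
    return tuple(x_shape[:-2])


def always_batched_output_shape(x_shape: Shape) -> Shape:
    # Hypothetical implementation: promote q x d to 1 x q x d and
    # keep that batch dim in the output.
    if len(x_shape) == 2:
        x_shape = (1,) + tuple(x_shape)
    return tuple(x_shape[:-2])


def violates_contract(x_shape: Shape, out_shape: Shape) -> bool:
    # True if the returned shape differs from what the contract promises.
    return tuple(out_shape) != contract_output_shape(x_shape)


for x_shape in [(4, 3, 2), (3, 2)]:
    out = always_batched_output_shape(x_shape)
    status = "violation" if violates_contract(x_shape, out) else "ok"
    print(x_shape, "->", out, status)
```

Only the unbatched case is flagged; the batched case behaves identically under both rules, which is why a test that only exercises b x q x d inputs would not catch this.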
Any thoughts on how we should approach this, i.e., is this actually a more fundamental problem in the acquisition function code?
Have you read the Contributing Guidelines on pull requests?
Yes
Test Plan
A corresponding test has been included.
@sdaulton has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Thanks for the PR! It looks like a test is failing. Would you mind fixing that please?
Sorry, my bad, I forgot to run the entire test suite. It should now be fixed, though 4 other tests also seem to be failing on the main branch?
Still, as I mentioned earlier: while the change fixes the particular problem reported, I don't think it properly addresses the root cause. So input would be welcome 🙃
Codecov Report
All modified and coverable lines are covered by tests ✅
Project coverage is 99.99%. Comparing base (9a7c517) to head (3ced166).
Coverage diff (main vs. #2743):

| | main | #2743 | +/- |
|---|---|---|---|
| Coverage | 99.99% | 99.99% | |
| Files | 203 | 203 | |
| Lines | 18690 | 18692 | +2 |
| Hits | 18689 | 18691 | +2 |
| Misses | 1 | 1 | |