Batched inference CEBRA & padding at the `Solver` level
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/746
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/624
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/637
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/594
@stes @MMathisLab, if you have time to review this that would be great :)
@CeliaBenquet can you resolve the conflicts? Then I think it's fine to merge!
@MMathisLab there have been big code changes / refactoring since @stes's last review, so I would be more confident about merging after an in-depth review, but your call :)
reviewing now
Re your comment on changes not related to batched inference: if I'm correct (I'm not the one who started it), the PR began with 2 related goals at once:
- batched inference,
- moving all padding and related operations in the transform from the cebra.CEBRA() class to the solver, which is needed to do batched inference in the solver.
--> See the other linked issues for more context.
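To illustrate why the padding and the batching have to live in the same place, here is a minimal sketch in plain NumPy. It is not the code from this PR: `moving_average` is a toy stand-in for a model with a receptive field, and `transform_full`, `transform_batched`, `left`, `right` and the edge-padding mode are all assumptions made for the example. The point is only that each batch needs its own context (real neighbours inside the recording, padding at the two ends) for the batched output to match the unbatched one.

```python
import numpy as np


def moving_average(x, left, right):
    """Toy 'model' with a receptive field: windowed mean without any padding.

    Output has len(x) - (left + right) rows, like a 'valid' convolution.
    """
    width = left + right + 1
    return np.stack(
        [x[i:i + width].mean(axis=0) for i in range(len(x) - width + 1)])


def transform_full(x, left, right):
    """Unbatched reference: pad the whole input once, then apply the model."""
    padded = np.pad(x, ((left, right), (0, 0)), mode="edge")
    return moving_average(padded, left, right)


def transform_batched(x, left, right, batch_size):
    """Batched variant: every batch gets its own context before the model runs."""
    n = len(x)
    outputs = []
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        # Take real neighbouring samples as context wherever they exist.
        lo, hi = max(start - left, 0), min(end + right, n)
        chunk = x[lo:hi]
        # Pad only what is missing, i.e. at the very start or end of the input.
        pad_left = left - (start - lo)
        pad_right = right - (hi - end)
        chunk = np.pad(chunk, ((pad_left, pad_right), (0, 0)), mode="edge")
        outputs.append(moving_average(chunk, left, right))
    return np.concatenate(outputs)
```

With this layout the last batch can hold a single sample and still produce one output row, because the context comes from the neighbouring samples rather than from the batch itself.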
Ok, makes sense!
Also updated this; I'll check again once the tests have passed.
Note: update to main first, after #212 is merged.
This error needs to be fixed in the usage.rst, most likely some "one sample left in last batch" error.
Need to handle the corner case where the input is shorter than 2 batches and batched inference is requested.
Solution: either raise an error, or silently fall back to normal inference, since memory is unlikely to be an issue in that case.
Something like:
from torch.utils.data import DataLoader

index_dataset = IndexDataset(inputs)
index_dataloader = DataLoader(index_dataset, batch_size=batch_size)
if len(index_dataloader) < 2:
    raise ValueError(
        f"Number of batches must be greater than 1, you can use transform without batching instead, got {len(index_dataloader)}."
    )
[...]
if batch_size is not None and inputs.shape[0] > int(batch_size*2):
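For the silent-fallback option, the decision could be isolated in a small helper so it is easy to test on its own. This is only a sketch: `use_batched_inference` is a hypothetical name, not something in the PR, and the threshold simply mirrors the `inputs.shape[0] > int(batch_size*2)` condition above.

```python
def use_batched_inference(num_samples, batch_size):
    """Decide whether to run batched inference (hypothetical helper).

    Returns False when no batch size is given, or when the input is not
    longer than two batches, so the caller falls back to normal inference.
    """
    if batch_size is None:
        return False
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}.")
    # Same threshold as the condition in the snippet above.
    return num_samples > 2 * batch_size
```

The caller would then branch once on the result, rather than raising when the dataloader ends up with fewer than 2 batches.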
@CeliaBenquet @stes what else is needed for this? It looks like there's a conflict, and then a tests issue?
Hey @CeliaBenquet can you resolve conflicts, then we can merge since this is internally in use!
=========================== short test summary info ============================
ERROR tests/test_benchmark.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_benchmark.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_benchmark.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_data_helper.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_data_helper.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_data_helper.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_datasets.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_datasets.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_datasets.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_distributions.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_distributions.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_distributions.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_integration_train.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_integration_train.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_integration_train.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_solver.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_solver.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_solver.py - NameError: name 'get_datapath' is not defined
ERROR cebra/config.py - NameError: name 'get_datapath' is not defined
ERROR cebra/config.py - NameError: name 'get_datapath' is not defined
!!!!!!!!!!!!!!!!!!! Interrupted: 20 errors during collection !!!!!!!!!!!!!!!!!!!
doc error is:
/home/runner/work/CEBRA/CEBRA/cebra/data/single_session.py:docstring of cebra.data.single_session.SingleSessionDataset.configure_for:3: WARNING: py:attr reference target not found: cebra_data.Dataset.offset
/home/runner/work/CEBRA/CEBRA/cebra/data/multi_session.py:docstring of cebra.data.multi_session.MultiSessionDataset.configure_for:3: WARNING: py:attr reference target not found: cebra_data.Dataset.offset
@CeliaBenquet not sure I see your edits post review; did you push them?
At the risk of it not being perfect, let's merge this now; @CeliaBenquet can document the remaining questions on the API design in an issue, but getting #251 merged is a priority 🦾