
Batched inference CEBRA & padding at the `Solver` level

Open CeliaBenquet opened this issue 1 year ago • 7 comments

fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/746

fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/624
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/637
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/594

CeliaBenquet avatar Aug 23 '24 12:08 CeliaBenquet

@stes @MMathisLab, if you have time to review this that would be great :)

CeliaBenquet avatar Sep 18 '24 09:09 CeliaBenquet

@CeliaBenquet can you resolve the conflicts? Then I think it's fine to merge!

MMathisLab avatar Sep 18 '24 09:09 MMathisLab

@MMathisLab there have been big code changes / refactoring since @stes's last review, so I would be more confident about merging after an in-depth review, but your call :)

CeliaBenquet avatar Sep 18 '24 11:09 CeliaBenquet

reviewing now

stes avatar Sep 18 '24 12:09 stes

Re your comment on the changes not related to batched inference: the PR was started with 2 (related) goals at once, if I'm correct (I'm not the one who started it):

  • batched inference,
  • and, to do that in the solver, all padding etc. operations in the transform needed to be moved from the `cebra.CEBRA()` class to the solver.

--> see other linked issues for better understanding.
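
To illustrate why the two goals are coupled, here is a minimal sketch, not the code from this PR. It assumes a model with a fixed receptive field of `left + right + 1` samples; `toy_model` and `batched_transform` are hypothetical names, and a moving average stands in for the real encoder. Padding is applied once to the whole sequence, and each batch then reads its neighbouring context from the padded array, so the batched output matches the unbatched one even when the last batch holds a single sample.

```python
import numpy as np


def toy_model(x, window):
    """Hypothetical stand-in for a convolutional encoder.

    A moving average without padding: maps (T, D) to (T - window + 1, D).
    """
    n_out = x.shape[0] - window + 1
    return np.stack([x[i:i + window].mean(axis=0) for i in range(n_out)])


def batched_transform(inputs, batch_size, left, right):
    """Run toy_model batch by batch, returning one output row per input row."""
    window = left + right + 1
    # Pad the sequence ends once by repeating the edge samples
    # (an assumption; other padding modes are possible).
    padded = np.pad(inputs, ((left, right), (0, 0)), mode="edge")
    outputs = []
    for start in range(0, inputs.shape[0], batch_size):
        stop = min(start + batch_size, inputs.shape[0])
        # In padded coordinates, rows [start, stop + left + right) hold
        # exactly the context needed for output rows [start, stop).
        chunk = padded[start:stop + left + right]
        outputs.append(toy_model(chunk, window))
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(size=(11, 3))
full = batched_transform(x, batch_size=11, left=2, right=2)
# With batch_size=5 the last batch contains a single sample.
batched = batched_transform(x, batch_size=5, left=2, right=2)
print(np.allclose(full, batched))
```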

CeliaBenquet avatar Sep 18 '24 13:09 CeliaBenquet

Ok, makes sense!

stes avatar Sep 18 '24 20:09 stes

Also updated this, and will check again once the tests have passed.

stes avatar Oct 20 '24 15:10 stes

note, update to main first after #212 is merged

stes avatar Jan 21 '25 23:01 stes

This error needs to be fixed in the usage.rst, most likely some "one sample left in last batch" error.

image

stes avatar Jan 21 '25 23:01 stes

Need to handle the corner case where the input is shorter than 2 batches and batched inference is requested.

Solution: either raise an error, or silently fall back to normal inference, since memory is unlikely to be an issue in that case.

Something like:


index_dataset = IndexDataset(inputs)
index_dataloader = DataLoader(index_dataset, batch_size=batch_size)

# Option 1: refuse to batch when there is only a single batch.
if len(index_dataloader) < 2:
    raise ValueError(
        f"Number of batches must be greater than 1, you can use transform without batching instead, got {len(index_dataloader)}."
    )

[...]
    # Option 2: only take the batched path when the input spans more than two batches.
    if batch_size is not None and inputs.shape[0] > int(batch_size * 2):
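
As an illustration of the second option (silent fallback), here is a minimal, self-contained sketch. It is not the code from this PR: `transform_with_fallback` and `transform_fn` are hypothetical names, and plain NumPy slicing stands in for the `DataLoader`. The batching condition mirrors the `inputs.shape[0] > batch_size * 2` check above.

```python
import numpy as np


def transform_with_fallback(inputs, transform_fn, batch_size=None):
    """Apply transform_fn in batches, or in one call for short inputs.

    transform_fn is a hypothetical callable mapping (N, D) to (N, D_out).
    """
    n_samples = inputs.shape[0]
    # Fall back to a single pass when batching is not requested or the
    # input does not span more than two batches.
    if batch_size is None or n_samples <= 2 * batch_size:
        return transform_fn(inputs)
    outputs = [
        transform_fn(inputs[start:start + batch_size])
        for start in range(0, n_samples, batch_size)
    ]
    return np.concatenate(outputs, axis=0)
```

For example, with 10 samples and `batch_size=3` the function makes four calls (3, 3, 3 and 1 samples), while with `batch_size=8` it makes a single call on all 10 samples.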

CeliaBenquet avatar Jan 22 '25 08:01 CeliaBenquet

@CeliaBenquet @stes what else is needed for this? It seems there's a conflict, and then a tests issue?

MMathisLab avatar Feb 18 '25 09:02 MMathisLab

Hey @CeliaBenquet can you resolve the conflicts? Then we can merge, since this is already in use internally!

MMathisLab avatar Apr 23 '25 05:04 MMathisLab

=========================== short test summary info ============================
ERROR tests/test_benchmark.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_benchmark.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_benchmark.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_data_helper.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_data_helper.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_data_helper.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_datasets.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_datasets.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_datasets.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_distributions.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_distributions.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_distributions.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_integration_train.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_integration_train.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_integration_train.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_solver.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_solver.py - NameError: name 'get_datapath' is not defined
ERROR tests/test_solver.py - NameError: name 'get_datapath' is not defined
ERROR cebra/config.py - NameError: name 'get_datapath' is not defined
ERROR cebra/config.py - NameError: name 'get_datapath' is not defined
!!!!!!!!!!!!!!!!!!! Interrupted: 20 errors during collection !!!!!!!!!!!!!!!!!!!

MMathisLab avatar Apr 23 '25 16:04 MMathisLab

doc error is:
/home/runner/work/CEBRA/CEBRA/cebra/data/single_session.py:docstring of cebra.data.single_session.SingleSessionDataset.configure_for:3: WARNING: py:attr reference target not found: cebra_data.Dataset.offset
/home/runner/work/CEBRA/CEBRA/cebra/data/multi_session.py:docstring of cebra.data.multi_session.MultiSessionDataset.configure_for:3: WARNING: py:attr reference target not found: cebra_data.Dataset.offset

MMathisLab avatar Apr 24 '25 09:04 MMathisLab

@CeliaBenquet not sure I see your edits post review; did you push them?

MMathisLab avatar Apr 24 '25 15:04 MMathisLab

At the risk of it not being perfect, let's merge this now; @CeliaBenquet can document the remaining Qs on the API design in an issue, but getting #251 merged is a priority 🦾

MMathisLab avatar May 23 '25 13:05 MMathisLab