avalanche
make_train_dataloader discards custom collate function passed as kwarg
Describe the bug
Calling
    cl_strategy.train(
        experience,
        eval_streams=[val_exp],
        num_workers=4,
        collate_fn=my_custom_collate,
    )
should forward all of the keyword arguments I pass in to the dataloader. Instead, my_custom_collate is silently discarded.
To Reproduce
For debugging, I define a custom strategy to inspect what is passed into the dataloader. Its make_train_dataloader method is copied verbatim from the 0.4.0 implementation, with pdb breakpoints added.
from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
from avalanche.training.supervised import Naive


class CustomNaiveStrategy(Naive):
    def make_train_dataloader(
        self,
        num_workers=0,
        shuffle=True,
        pin_memory=None,
        persistent_workers=False,
        drop_last=False,
        **kwargs
    ):
        assert self.adapted_dataset is not None
        # fmt:off
        import pdb; pdb.set_trace()  # breakpoint 1: inspect kwargs (collate_fn is here)
        # fmt:on
        other_dataloader_args = self._obtain_common_dataloader_parameters(
            batch_size=self.train_mb_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
        )
        # fmt:off
        pdb.set_trace()  # breakpoint 2: collate_fn is missing from other_dataloader_args
        # fmt:on
        # kwargs is only consulted for "ffcv_args"; every other key is dropped.
        if "ffcv_args" in kwargs:
            other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]
        self.dataloader = TaskBalancedDataLoader(
            self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args
        )
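Stripped of the strategy machinery, the pattern above reduces to the following. This is a simplified, hypothetical stand-in written for illustration (`build_loader_args` and `make_loader_args` are not Avalanche functions): the keyword argument lands in `**kwargs`, but only the explicitly named parameters are copied into the dict that reaches the loader.

```python
def build_loader_args(batch_size, num_workers=0, shuffle=True):
    # Hypothetical stand-in for _obtain_common_dataloader_parameters:
    # it only knows about the parameters it names explicitly.
    return {"batch_size": batch_size, "num_workers": num_workers, "shuffle": shuffle}


def make_loader_args(num_workers=0, shuffle=True, **kwargs):
    # collate_fn (and anything else unnamed) is captured in kwargs here...
    args = build_loader_args(batch_size=32, num_workers=num_workers, shuffle=shuffle)
    # ...but only "ffcv_args" is copied over, so collate_fn is silently dropped.
    if "ffcv_args" in kwargs:
        args["ffcv_args"] = kwargs["ffcv_args"]
    return args


def my_custom_collate(batch):
    return batch


loader_args = make_loader_args(num_workers=4, collate_fn=my_custom_collate)
print("collate_fn" in loader_args)  # False
```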
Expected behavior
other_dataloader_args should include the entries from kwargs, so that my_custom_collate is passed on to TaskBalancedDataLoader.
Screenshots
In the screenshot above, `p kwargs` shows the custom collate function, but it does not appear in other_dataloader_args, which is what gets passed on to TaskBalancedDataLoader.
Additional context
I can't immediately see why something like
other_dataloader_args.update(kwargs) would be a bad idea; I'd love to hear thoughts.
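For context on the `update(kwargs)` idea: `dict.update` copies every key from its argument, overwriting keys that already exist. The minimal sketch below uses plain dicts with placeholder values (not Avalanche's actual defaults). It shows that a plain update would forward `collate_fn`, but also any unrelated or misspelled key, which would then most likely fail further down in the loader constructor. Filtering through an allowlist (hypothetical here) is one more conservative variant.

```python
def my_custom_collate(batch):
    return batch


# Placeholder for the arguments the strategy computes itself.
other_dataloader_args = {"batch_size": 32, "num_workers": 4, "shuffle": True}

# Placeholder for what ends up in **kwargs; "shufle" is a deliberate typo.
kwargs = {"collate_fn": my_custom_collate, "shufle": False}

# Plain update: collate_fn is forwarded, but so is the typo.
merged = dict(other_dataloader_args)
merged.update(kwargs)
print(sorted(merged))
# ['batch_size', 'collate_fn', 'num_workers', 'shuffle', 'shufle']

# Allowlist variant: only keys known to be meaningful for the loader are forwarded.
FORWARDED_KEYS = {"collate_fn", "ffcv_args"}  # hypothetical allowlist
selective = dict(other_dataloader_args)
selective.update({k: v for k, v in kwargs.items() if k in FORWARDED_KEYS})
print(sorted(selective))
# ['batch_size', 'collate_fn', 'num_workers', 'shuffle']
```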
This was supposedly fixed in #1089, or at least this is mentioned there:
Dataloading in strategies now checks if the dataset has a "collate_fn" function and uses that unless one is specified through kwargs (which takes precedence).
However, what I observe above doesn't match that description. Either way, #1089 seems relevant here.
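The precedence rule quoted from #1089 can be written as a small resolution function. This is a hypothetical sketch (`resolve_collate_fn` is not an Avalanche function); it assumes the dataset exposes its default as a `collate_fn` attribute, as the quote describes, and it treats an explicit `collate_fn=None` the same as not passing one.

```python
def resolve_collate_fn(dataset, kwargs):
    """Return the collate function to use: kwargs first, then the dataset's own."""
    # An explicit (non-None) collate_fn passed through kwargs takes precedence.
    if kwargs.get("collate_fn") is not None:
        return kwargs["collate_fn"]
    # Otherwise fall back to the dataset's collate_fn, or None if it has none.
    return getattr(dataset, "collate_fn", None)


class DatasetWithCollate:
    """Toy dataset that provides its own collate_fn."""

    def collate_fn(self, batch):
        return ("dataset", batch)


def user_collate(batch):
    return ("user", batch)


ds = DatasetWithCollate()
print(resolve_collate_fn(ds, {"collate_fn": user_collate}) is user_collate)  # True
print(resolve_collate_fn(ds, {})(["x"]))  # ('dataset', ['x'])
print(resolve_collate_fn(object(), {}))  # None
```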
This is definitely a bug. Can you submit a PR that properly adds collate_fn to other_dataloader_args? That should be the only change needed.
Should this fix be done by updating _obtain_common_dataloader_parameters, or is there a more Avalanche-style way of doing it? The hotfix of other_dataloader_args.update(kwargs) doesn't seem very Avalanche-y (but maybe I'm wrong!).
I will also write a test to check that a collate_fn passed as a kwarg takes precedence over the dataset's collate_fn.
Feel free to assign this to me, thanks.
I think updating _obtain_common_dataloader_parameters is the best way.
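The direction agreed on here (have the common-parameter builder accept kwargs and resolve `collate_fn`), together with the precedence test proposed above, could look roughly like the sketch below. Everything in it is hypothetical: `obtain_common_dataloader_parameters` is a simplified stand-in, not the real `_obtain_common_dataloader_parameters` signature, and the toy dataset is made up for the test.

```python
def obtain_common_dataloader_parameters(
    dataset, batch_size, num_workers=0, shuffle=True, **kwargs
):
    """Hypothetical builder: named defaults plus a collate_fn resolved with
    kwargs taking precedence over the dataset's own collate_fn."""
    params = {"batch_size": batch_size, "num_workers": num_workers, "shuffle": shuffle}
    collate_fn = kwargs.get("collate_fn")
    if collate_fn is None:
        # No explicit collate_fn: fall back to the dataset's, if it has one.
        collate_fn = getattr(dataset, "collate_fn", None)
    if collate_fn is not None:
        params["collate_fn"] = collate_fn
    return params


class DatasetWithCollate:
    """Toy dataset exposing its own collate_fn."""

    def collate_fn(self, batch):
        return ("dataset", batch)


def kwarg_collate(batch):
    return ("kwarg", batch)


def test_kwarg_collate_takes_precedence():
    params = obtain_common_dataloader_parameters(
        DatasetWithCollate(), batch_size=8, collate_fn=kwarg_collate
    )
    assert params["collate_fn"] is kwarg_collate


def test_dataset_collate_is_fallback():
    params = obtain_common_dataloader_parameters(DatasetWithCollate(), batch_size=8)
    assert params["collate_fn"]([1, 2]) == ("dataset", [1, 2])


if __name__ == "__main__":
    test_kwarg_collate_takes_precedence()
    test_dataset_collate_is_fallback()
    print("ok")
```

Written this way, the test functions can also be collected by pytest as-is.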
Was this fixed in the meantime?