learn2learn
learn2learn copied to clipboard
LightningPrototypicalNetworks example fails
Hi there, why is this example not working? Why does cross_entropy fail?
import learn2learn as l2l
import pytorch_lightning as pl
tasksets = l2l.vision.benchmarks.get_tasksets('omniglot', root="D:/datasets/omniglot/")
features = l2l.vision.models.OmniglotCNN()
protonet = LightningPrototypicalNetworks(features)
episodic_data = EpisodicBatcher(tasksets.train, tasksets.validation, tasksets.test)
trainer = pl.Trainer()
trainer.fit(protonet, episodic_data)
cell output:
in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
648 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
649 else:
--> 650 return trainer_fn(*args, **kwargs)
651 # TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
652 except KeyboardInterrupt as exception:
...
3012 if size_average is not None or reduce is not None:
3013 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3014 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
IndexError: Target 1 is out of bounds.