deepethogram
[COLAB] FlowGenerator error in Colab related to kornia augmentation or the pytorch-lightning package
Hi,
I am using Colab to train on the test data (testing_deepethogram_archive.zip) provided on GitHub.
During the
flow_generator = flow_generator_train(cfg)
step, after epoch 0 reaches 70%, I get the kornia augmentation error below.
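For context, the training cell follows the DeepEthogram Colab demo notebook; roughly like this (a sketch, with my project path shortened and assuming the demo's make_flow_generator_train_cfg helper):

from deepethogram.configuration import make_flow_generator_train_cfg
from deepethogram.flow_generator.train import flow_generator_train

# path shortened here; the full path appears in the logs below
project_path = '/content/drive/MyDrive/.../testingdeg_deepethogram'
cfg = make_flow_generator_train_cfg(project_path=project_path)  # build the flow-generator training config
flow_generator = flow_generator_train(cfg)  # this is the call that crashes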
ERROR
Epoch 0: 70%
46/66 [01:43<00:44, 2.24s/it, loss=0.0427, v_num=0]
IndexError Traceback (most recent call last)
<ipython-input-19-1d7b632134f1> in <cell line: 1>()
----> 1 flow_generator = flow_generator_train(cfg)
51 frames
/usr/local/lib/python3.10/dist-packages/deepethogram/flow_generator/train.py in flow_generator_train(cfg)
77
78 trainer = get_trainer_from_cfg(cfg, lightning_module, stopper)
---> 79 trainer.fit(lightning_module)
80 return flow_generator
81
... (intermediate frames omitted; the full traceback is below)
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/_2d/base.py in generate_transformation_matrix(self, input, params, flags)
81 else:
82 trans_matrix_A = self.identity_matrix(in_tensor)
---> 83 trans_matrix_B = self.compute_transformation(in_tensor[to_apply], params=params, flags=flags)
84
85 if is_autocast_enabled():
IndexError: The shape of the mask [352] at index 0 does not match the shape of the indexed tensor [308, 3, 224, 224] at index 0
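A detail that may matter: the shapes in the error line up with my config below (batch_size: 32, and 11 input images per clip). A quick check:

# 352 = frames in a full batch; 308 = frames in a 28-clip batch
batch_size, frames_per_clip = 32, 11  # from compute.batch_size and flow_generator.input_images below
assert batch_size * frames_per_clip == 352
assert 28 * frames_per_clip == 308

So the tensor being indexed apparently comes from a smaller batch (28 clips), while the kornia mask seems to have been generated for a full 32-clip batch.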
INSTALLED LIBRARIES
I installed opencv-python because the Colab notebook did not run without cv2.
print('\n=================ORIGINAL DEG INSTALLATION CODE=========================\n')
!pip uninstall -y opencv-python
print('\n==========================================\n')
!pip install --upgrade deepethogram
print('\n==========================================\n')
!pip uninstall -y torchtext # this is for pytorch lightning compatibility
print('\n======================DEBUGGING Karin====================\n')
# GitHub issue: the "gpus" argument is no longer present in newer versions of pytorch-lightning,
# so an older version must be installed, for example pytorch-lightning==1.5.10
!pip install pytorch-lightning==1.5.10
print('\n======================DEBUGGING Karin: add opencv====================\n')
!pip install opencv-python
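For completeness, versions can be checked after the installs with a small cell like this (a sketch using importlib.metadata, which is in the standard library on Colab's Python 3.10):

from importlib.metadata import version
# print the versions of the packages involved in this error
for pkg in ('torch', 'pytorch-lightning', 'kornia', 'opencv-python', 'deepethogram'):
    print(pkg, version(pkg))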
FULL ERROR
[2023-10-22 13:52:44,519] INFO [deepethogram.projects.convert_config_paths_to_absolute:1135] cwd in absolute: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/test3_deepethogram/models/231022_133620_flow_generator_train
[2023-10-22 13:52:44,525] INFO [deepethogram.projects.convert_config_paths_to_absolute:1178] after absolute: {'class_names': ['background', 'face_groom', 'body_groom', 'dig', 'scratch'], 'config_file': '/content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/project_config.yaml', 'data_path': '/content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/DATA', 'labeler': None, 'model_path': '/content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models', 'name': 'testing', 'path': '/content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram', 'pretrained_path': '/content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/pretrained_models'}
[2023-10-22 13:52:44,548] INFO [deepethogram.flow_generator.train.flow_generator_train:54] args: /usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py -f /root/.local/share/jupyter/runtime/kernel-f6ad82e8-3203-4d02-9752-8ad350edbc18.json
[2023-10-22 13:52:44,552] INFO [deepethogram.flow_generator.train.flow_generator_train:62] configuration used ~~~~~
[2023-10-22 13:52:44,563] INFO [deepethogram.flow_generator.train.flow_generator_train:63] split:
split:
  reload: true
  file: null
  train_val_test:
  - 0.8
  - 0.2
  - 0.0
compute:
  fp16: false
  num_workers: 2
  batch_size: 32
  min_batch_size: 8
  max_batch_size: 512
  distributed: false
  gpu_id: 0
  dali: false
  metrics_workers: 0
reload:
  overwrite_cfg: false
  latest: false
notes: null
log:
  level: info
augs:
  brightness: 0.25
  contrast: 0.1
  hue: 0.1
  saturation: 0.1
  color_p: 0.5
  grayscale: 0.5
  crop_size: null
  resize:
  - 224
  - 224
  dali: false
  random_resize: false
  pad: null
  LR: 0.5
  UD: 0.5
  degrees: 10
  normalization:
    'N': 13125000
    mean:
    - 0.02004539966386554
    - 0.03199181684407095
    - 0.025961602390289447
    std:
    - 0.02522799020705389
    - 0.05607626687605602
    - 0.03893020334412448
train:
  lr: 0.0001
  scheduler: plateau
  num_epochs: 10
  steps_per_epoch:
    train: 1000
    val: 200
    test: 20
  min_lr: 5.0e-07
  stopping_type: learning_rate
  milestones:
  - 50
  - 100
  - 150
  - 200
  - 250
  - 300
  weight_loss: true
  patience: 3
  early_stopping_begins: 0
  viz_metrics: true
  viz_examples: 10
  reduction_factor: 0.1
  loss_weight_exp: 1.0
  loss_gamma: 1.0
  label_smoothing: 0.05
  oversampling_exp: 0.0
  regularization:
    style: l2_sp
    alpha: 1.0e-05
    beta: 0.001
flow_generator:
  type: flow_generator
  flow_loss: MotionNet
  flow_max: 10
  input_images: 11
  flow_sparsity: false
  smooth_weight_multiplier: 1.0
  sparsity_weight: 0.0
  loss: MotionNet
  max: 5
  n_rgb: 11
  arch: TinyMotionNet
  weights: pretrained
  'n': 10
feature_extractor:
  arch: resnet18
  n_flow: 10
  n_rgb: 1
project:
  class_names:
  - background
  - face_groom
  - body_groom
  - dig
  - scratch
  config_file: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/project_config.yaml
  data_path: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/DATA
  labeler: null
  model_path: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models
  name: testing
  path: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram
  pretrained_path: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/pretrained_models
sequence:
  filter_length: 15
run:
  type: train
  model: flow_generator
  dir: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/231022_135244_flow_generator_train
[2023-10-22 13:52:49,197] INFO [deepethogram.flow_generator.train.flow_generator_train:67] Total trainable params: 1,951,784
[2023-10-22 13:52:51,772] INFO [deepethogram.projects.get_weightfile_from_cfg:1068] loading pretrained weights: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/pretrained_models/200221_115158_TinyMotionNet/checkpoint.pt
[2023-10-22 13:52:51,776] INFO [deepethogram.utils.load_state:341] loading from checkpoint file /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/pretrained_models/200221_115158_TinyMotionNet/checkpoint.pt...
reloading weights...
[2023-10-22 13:52:52,300] INFO [deepethogram.flow_generator.train.get_metrics:364] key metric is SSIM
[2023-10-22 13:52:52,325] INFO [deepethogram.data.augs.get_gpu_transforms:246] GPU transforms: {'train': Sequential(
(0): ToFloat()
(1): VideoSequential(
(RandomHorizontalFlip_0): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False)
(RandomVerticalFlip_1): RandomVerticalFlip(p=0.5, p_batch=1.0, same_on_batch=False)
(RandomRotation_2): RandomRotation(degrees=10, p=0.5, p_batch=1.0, same_on_batch=False, resample=bilinear, align_corners=True)
(ColorJitter_3): ColorJitter(brightness=0.25, contrast=0.1, saturation=0.1, hue=0.1, p=0.5, p_batch=1.0, same_on_batch=False)
(RandomGrayscale_4): RandomGrayscale(p=0.5, p_batch=1.0, same_on_batch=False)
)
(2): NormalizeVideo()
(3): StackClipInChannels()
), 'val': Sequential(
(0): ToFloat()
(1): NormalizeVideo()
(2): StackClipInChannels()
), 'test': Sequential(
(0): ToFloat()
(1): NormalizeVideo()
(2): StackClipInChannels()
), 'denormalize': Sequential(
(0): UnstackClip()
(1): DenormalizeVideo()
)}
[2023-10-22 13:52:52,326] INFO [deepethogram.base.__init__:95] scheduler mode: min
[2023-10-22 13:52:52,452] INFO [deepethogram.losses.get_regularization_loss:204] Regularization: L2_SP. Pretrained file: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/pretrained_models/200221_115158_TinyMotionNet/checkpoint.pt alpha: 1e-05 beta: 0.001
[2023-10-22 13:52:52,507] INFO [deepethogram.flow_generator.losses.__init__:178] Using MotionNet Loss with settings: smooth_weights: [0.01, 0.02, 0.04, 0.08, 0.16] flow_sparsity: False sparsity_weight: 0.0
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/callback_connector.py:90: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
rank_zero_deprecation(
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:88: LightningDeprecationWarning: `reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed in v1.6. Please use `reload_dataloaders_every_n_epochs` in Trainer.
rank_zero_deprecation(
[2023-10-22 13:52:52,517] INFO [pytorch_lightning.utilities.distributed._info:93] GPU available: True, used: True
[2023-10-22 13:52:52,519] INFO [pytorch_lightning.utilities.distributed._info:93] TPU available: False, using: 0 TPU cores
[2023-10-22 13:52:52,523] INFO [pytorch_lightning.utilities.distributed._info:93] IPU available: False, using: 0 IPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:275: LightningDeprecationWarning: The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7. Please use the `on_exception` callback hook instead.
rank_zero_deprecation(
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:291: LightningDeprecationWarning: Base `Callback.on_train_batch_start` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.
rank_zero_deprecation(
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:291: LightningDeprecationWarning: Base `Callback.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.
rank_zero_deprecation(
[2023-10-22 13:52:52,593] INFO [pytorch_lightning.accelerators.gpu.set_nvidia_flags:59] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2023-10-22 13:52:52,621] INFO [deepethogram.base.configure_optimizers:227] learning rate: 0.0001
[2023-10-22 13:52:52,624] WARNING [pytorch_lightning.loggers.tensorboard._get_next_version:298] Missing logger folder: /content/drive/MyDrive/Research/Schneider lab/Paper/Karin paper version 230911/Reviewers requests/Fig 5 Automated behavior classification/DeepEthogram/testingdeg_deepethogram/models/231022_135244_flow_generator_train/default
[2023-10-22 13:52:52,630] INFO [pytorch_lightning.callbacks.model_summary.summarize:73]
| Name | Type | Params
--------------------------------------------
0 | model | TinyMotionNet | 2.0 M
1 | criterion | MotionNetLoss | 0
--------------------------------------------
2.0 M Trainable params
0 Non-trainable params
2.0 M Total params
7.807 Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/data_loading.py:659: UserWarning: Your `val_dataloader` has `shuffle=True`, it is strongly recommended that you turn this off for val/test/predict dataloaders.
rank_zero_warn(
Epoch 0: 70%
46/66 [01:43<00:44, 2.24s/it, loss=0.0427, v_num=0]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-19-1d7b632134f1> in <cell line: 1>()
----> 1 flow_generator = flow_generator_train(cfg)
51 frames
/usr/local/lib/python3.10/dist-packages/deepethogram/flow_generator/train.py in flow_generator_train(cfg)
77
78 trainer = get_trainer_from_cfg(cfg, lightning_module, stopper)
---> 79 trainer.fit(lightning_module)
80 return flow_generator
81
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
738 )
739 train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
741 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
742 )
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
683 """
684 try:
--> 685 return trainer_fn(*args, **kwargs)
686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
687 except KeyboardInterrupt as exception:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
775 # TODO: ckpt_path only in v1.7
776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
778
779 assert self.state.stopped
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1197
1198 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1199 self._dispatch()
1200
1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
1277 self.training_type_plugin.start_predicting(self)
1278 else:
-> 1279 self.training_type_plugin.start_training(self)
1280
1281 def run_stage(self):
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
200 def start_training(self, trainer: "pl.Trainer") -> None:
201 # double dispatch to initiate the training loop
--> 202 self._results = trainer.run_stage()
203
204 def start_evaluating(self, trainer: "pl.Trainer") -> None:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
1287 if self.predicting:
1288 return self._run_predict()
-> 1289 return self._run_train()
1290
1291 def _pre_training_routine(self):
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1317 self.fit_loop.trainer = self
1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1319 self.fit_loop.run()
1320
1321 def _run_evaluate(self) -> _EVALUATE_OUTPUT:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
232
233 with self.trainer.profiler.profile("run_training_epoch"):
--> 234 self.epoch_loop.run(data_fetcher)
235
236 # the global step is manually decreased here due to backwards compatibility with existing loggers
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in advance(self, *args, **kwargs)
191
192 with self.trainer.profiler.profile("run_training_batch"):
--> 193 batch_output = self.batch_loop.run(batch, batch_idx)
194
195 self.batch_progress.increment_processed()
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py in advance(self, batch, batch_idx)
86 if self.trainer.lightning_module.automatic_optimization:
87 optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
---> 88 outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
89 else:
90 outputs = self.manual_loop.run(split_batch, batch_idx)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in advance(self, batch, *args, **kwargs)
213
214 def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
--> 215 result = self._run_optimization(
216 batch,
217 self._batch_idx,
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _run_optimization(self, split_batch, batch_idx, optimizer, opt_idx)
264 # gradient update with accumulated gradients
265 else:
--> 266 self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
267
268 result = closure.consume_result()
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
376
377 # model hook
--> 378 lightning_module.optimizer_step(
379 self.trainer.current_epoch,
380 batch_idx,
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
1650
1651 """
-> 1652 optimizer.step(closure=optimizer_closure)
1653
1654 def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py in step(self, closure, **kwargs)
162 assert trainer is not None
163 with trainer.profiler.profile(profiler_action):
--> 164 trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
337 """
338 model = model or self.lightning_module
--> 339 self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
340
341 def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py in optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs)
161 if isinstance(model, pl.LightningModule):
162 closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 163 optimizer.step(closure=closure, **kwargs)
164
165 def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
371 )
372
--> 373 out = func(*args, **kwargs)
374 self._optimizer_step_code()
375
/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py in _use_grad(self, *args, **kwargs)
74 torch.set_grad_enabled(self.defaults['differentiable'])
75 torch._dynamo.graph_break()
---> 76 ret = func(self, *args, **kwargs)
77 finally:
78 torch._dynamo.graph_break()
/usr/local/lib/python3.10/dist-packages/torch/optim/adam.py in step(self, closure)
141 if closure is not None:
142 with torch.enable_grad():
--> 143 loss = closure()
144
145 for group in self.param_groups:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py in _wrap_closure(self, model, optimizer, optimizer_idx, closure)
146 consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
147 """
--> 148 closure_result = closure()
149 self._after_closure(model, optimizer, optimizer_idx)
150 return closure_result
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in __call__(self, *args, **kwargs)
158
159 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 160 self._result = self.closure(*args, **kwargs)
161 return self._result.loss
162
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in closure(self, *args, **kwargs)
140 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
141 with self._profiler.profile("training_step_and_backward"):
--> 142 step_output = self._step_fn()
143
144 if step_output.closure_loss is None:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _training_step(self, split_batch, batch_idx, opt_idx)
433 lightning_module._current_fx_name = "training_step"
434 with self.trainer.profiler.profile("training_step"):
--> 435 training_step_output = self.trainer.accelerator.training_step(step_kwargs)
436 self.trainer.training_type_plugin.post_training_step()
437
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, step_kwargs)
217 """
218 with self.precision_plugin.train_step_context():
--> 219 return self.training_type_plugin.training_step(*step_kwargs.values())
220
221 def post_training_step(self) -> None:
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
211
212 def training_step(self, *args, **kwargs):
--> 213 return self.model.training_step(*args, **kwargs)
214
215 def post_training_step(self):
/usr/local/lib/python3.10/dist-packages/deepethogram/flow_generator/train.py in training_step(self, batch, batch_idx)
180
181 def training_step(self, batch: dict, batch_idx: int):
--> 182 return self.common_step(batch, batch_idx, 'train')
183
184 def validation_step(self, batch: dict, batch_idx: int):
/usr/local/lib/python3.10/dist-packages/deepethogram/flow_generator/train.py in common_step(self, batch, batch_idx, split)
161 """
162 # forward pass. images are returned because the forward pass runs augmentations on the gpu as well
--> 163 images, outputs = self(batch, split)
164 # actually reconstruct t0 using t1 and estimated optic flow
165 downsampled_t0, estimated_t0, flows_reshaped = self.reconstructor(images, outputs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/deepethogram/flow_generator/train.py in forward(self, batch, mode)
265 # lightning handles transfer to device
266 images = batch['images']
--> 267 images = self.apply_gpu_transforms(images, mode)
268
269 outputs = self.model(images)
/usr/local/lib/python3.10/dist-packages/deepethogram/base.py in apply_gpu_transforms(self, images, mode)
214 def apply_gpu_transforms(self, images: torch.Tensor, mode: str) -> torch.Tensor:
215 with torch.no_grad():
--> 216 images = self.gpu_transforms[mode](images).detach()
217 return images
218
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py in forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input
217
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/container/video.py in forward(self, input, params, extra_args)
336 params = self._params
337
--> 338 output = self.transform_inputs(input, params, extra_args=extra_args)
339
340 return output
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/container/video.py in transform_inputs(self, input, params, extra_args)
204 input = self._input_shape_convert_in(input, frame_num)
205
--> 206 input = super().transform_inputs(input, params, extra_args=extra_args)
207
208 input = self._input_shape_convert_back(input, frame_num)
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/container/base.py in transform_inputs(self, input, params, extra_args)
196 for param in params:
197 module = self.get_submodule(param.name)
--> 198 input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
199 return input
200
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/container/ops.py in transform(cls, input, module, param, extra_args)
157 def transform(cls, input: Tensor, module: Module, param: ParamItem, extra_args: Dict[str, Any] = {}) -> Tensor:
158 if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2)):
--> 159 input = module(input, params=cls.get_instance_module_param(param), **extra_args)
160 elif isinstance(module, (K.container.ImageSequentialBase,)):
161 input = module.transform_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/base.py in forward(self, input, params, **kwargs)
208 params, flags = self._process_kwargs_to_params_and_flags(params, self.flags, **kwargs)
209
--> 210 output = self.apply_func(in_tensor, params, flags)
211 return self.transform_output_tensor(output, input_shape) if self.keepdim else output
212
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/_2d/base.py in apply_func(self, in_tensor, params, flags)
122 flags = self.flags
123
--> 124 trans_matrix = self.generate_transformation_matrix(in_tensor, params, flags)
125 output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
126 self._transform_matrix = trans_matrix
/usr/local/lib/python3.10/dist-packages/kornia/augmentation/_2d/base.py in generate_transformation_matrix(self, input, params, flags)
81 else:
82 trans_matrix_A = self.identity_matrix(in_tensor)
---> 83 trans_matrix_B = self.compute_transformation(in_tensor[to_apply], params=params, flags=flags)
84
85 if is_autocast_enabled():
IndexError: The shape of the mask [352] at index 0 does not match the shape of the indexed tensor [308, 3, 224, 224] at index 0
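In case it helps to reproduce this outside of DeepEthogram: the failing module is the kornia VideoSequential shown in the GPU-transforms log above. Below is a standalone sketch of that 'train' pipeline, reconstructed from the log without DEG's ToFloat/Normalize/Stack wrappers; this is my reconstruction and is not guaranteed to trigger the crash, which seems tied to the batch shrinking mid-epoch:

import torch
import kornia.augmentation as K

# the five augmentations listed in the GPU-transforms log, with the same parameters
aug = K.VideoSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomRotation(degrees=10, p=0.5),
    K.ColorJitter(brightness=0.25, contrast=0.1, saturation=0.1, hue=0.1, p=0.5),
    K.RandomGrayscale(p=0.5),
    data_format='BCTHW',  # assumed (N, C, T, H, W); adjust if the clips are BTCHW
)

full = torch.rand(32, 3, 11, 224, 224)  # a full batch: 32 clips x 11 frames = 352 frames
last = torch.rand(28, 3, 11, 224, 224)  # a smaller final batch: 28 x 11 = 308 frames
aug(full)
aug(last)  # a fresh call like this may well succeed; in my run the IndexError appears inside training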