Traceback (most recent call last):
File "/root/node03-nfs/model/baseline_model/train_stage_2.py", line 773, in <module>
main(config)
File "/root/node03-nfs/model/baseline_model/train_stage_2.py", line 602, in main
model_pred = net(
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
return model_forward(*args, **kwargs)
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, **kwargs)
File "/root/node03-nfs/model/baseline_model/train_stage_2.py", line 96, in forward
model_pred = self.denoising_unet(
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/node03-nfs/model/baseline_model/src/models/unet_3d.py", line 514, in forward
sample, res_samples = downsample_block(
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/node03-nfs/model/baseline_model/src/models/unet_3d_blocks.py", line 451, in forward
hidden_states = torch.utils.checkpoint.checkpoint(
File "/root/anaconda3/envs/animate/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 246, in checkpoint
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
ValueError: Unexpected keyword arguments: encoder_hidden_states,self_attention_additional_feats,mode
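Workaround (issue #133): turn off gradient checkpointing. The error occurs because the reentrant form of torch.utils.checkpoint.checkpoint (the default, use_reentrant=True) rejects keyword arguments. Passing use_reentrant=False, or passing these values positionally, should also avoid it.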