I am trying to train a multimer model, resuming from the JAX weights, and I get the following error:
Traceback (most recent call last):
File "/work/10110/abhinav22/ls6/openfold/train_openfold.py", line 706, in <module>
main(args)
File "/work/10110/abhinav22/ls6/openfold/train_openfold.py", line 455, in main
trainer.fit(
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
return function(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
results = self._run_stage()
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
self.fit_loop.run()
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
self.advance()
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
self._optimizer_step(batch_idx, closure)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
call._call_lightning_module_hook(
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 159, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1308, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 270, in optimizer_step
optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 129, in optimizer_step
closure_result = closure()
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
self._result = self.closure(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 129, in closure
step_output = self._step_fn()
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 317, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 389, in training_step
return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
wrapper_output = wrapper_module(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1822, in forward
loss = self.module(*inputs, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/work/10110/abhinav22/ls6/src/miniforge/envs/openfold_cuda12/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
out = method(*_args, **_kwargs)
File "/work/10110/abhinav22/ls6/openfold/train_openfold.py", line 110, in training_step
batch = multi_chain_permutation_align(out=outputs,
File "/work/10110/abhinav22/ls6/openfold/openfold/utils/multi_chain_permutation.py", line 533, in multi_chain_permutation_align
align, per_asym_residue_index = compute_permutation_alignment(out=out,
File "/work/10110/abhinav22/ls6/openfold/openfold/utils/multi_chain_permutation.py", line 485, in compute_permutation_alignment
r, x = calculate_optimal_transform(true_ca_poses,
File "/work/10110/abhinav22/ls6/openfold/openfold/utils/multi_chain_permutation.py", line 422, in calculate_optimal_transform
r, x = get_optimal_transform(
File "/work/10110/abhinav22/ls6/openfold/openfold/utils/multi_chain_permutation.py", line 105, in get_optimal_transform
r = kabsch_rotation(src_atoms, tgt_atoms)
File "/work/10110/abhinav22/ls6/openfold/openfold/utils/multi_chain_permutation.py", line 52, in kabsch_rotation
u, _, vt = torch.linalg.svd(torch.matmul(P.to(torch.float32).T,
RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling cusolverDnCreate(handle). If you keep seeing this error, you may use torch.backends.cuda.preferred_linalg_library() to try linear algebra operators with other supported backends. See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library
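For context, the failing step boils down to a GPU SVD inside kabsch_rotation. A stripped-down version of just that step (with random tensors standing in for the CA coordinate sets that multi_chain_permutation.py actually passes in) would be roughly:

import torch

# Hypothetical minimal repro of the operation that raises the cusolver error:
# an SVD of the 3x3 cross-covariance built from two (N, 3) coordinate sets.
# src_atoms / tgt_atoms are stand-ins for the true/predicted CA positions.
device = torch.device("cuda")
src_atoms = torch.randn(128, 3, device=device)
tgt_atoms = torch.randn(128, 3, device=device)
H = torch.matmul(src_atoms.to(torch.float32).T, tgt_atoms.to(torch.float32))
u, s, vt = torch.linalg.svd(H)  # the call that fails with CUSOLVER_STATUS_INTERNAL_ERROR
print(u.shape, s.shape, vt.shape)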
Any ideas on how to fix this? This is on a Slurm node with 3x A100 40 GB GPUs.
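The error message itself suggests switching the linear algebra backend away from cusolver. I assume that would mean something like the following early in the script, before any SVD is run, though I haven't verified that it actually avoids the error on this node:

import torch

# Workaround suggested by the error message: prefer MAGMA over cusolver for
# torch.linalg operators. "magma" is one of the backends listed in the PyTorch
# docs for torch.backends.cuda.preferred_linalg_library(); untested here.
torch.backends.cuda.preferred_linalg_library("magma")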