sam3
sam3 copied to clipboard
I reduced the resolution in the Hydra config from 1008 for fine-tuning, to avoid an OOM error. Now I get assertion errors instead.
I have tried reducing the resolution from 1008 to several other values, such as 512, 504, 490, and many more, but I keep getting errors.
Below is the traceback I get for resolution=448 and min_size=448:
[rank0]:[W1217 15:56:43.552433737 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1217 15:56:45.377000 539 torch/multiprocessing/spawn.py:169] Terminating process 550 via signal SIGTERM
Traceback (most recent call last):
File "/kaggle/working/sam3/sam3/train/train.py", line 339, in <module>
main(args)
File "/kaggle/working/sam3/sam3/train/train.py", line 310, in main
single_node_runner(cfg, main_port)
File "/kaggle/working/sam3/sam3/train/train.py", line 78, in single_node_runner
mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 296, in start_processes
while not context.join():
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 215, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
fn(i, *args)
File "/kaggle/working/sam3/sam3/train/train.py", line 58, in single_proc_run
trainer.run()
File "/kaggle/working/sam3/sam3/train/trainer.py", line 567, in run
self.run_train()
File "/kaggle/working/sam3/sam3/train/trainer.py", line 588, in run_train
outs = self.train_epoch(dataloader)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/train/trainer.py", line 809, in train_epoch
self._run_step(batch, phase, loss_mts, extra_loss_mts)
File "/kaggle/working/sam3/sam3/train/trainer.py", line 946, in _run_step
loss_dict, batch_size, extra_losses = self._step(
^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/train/trainer.py", line 501, in _step
find_stages = model(batch)
^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/distributed.py", line 1637, in forward
else self._run_ddp_forward(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/distributed.py", line 1464, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/sam3_image.py", line 533, in forward
backbone_out.update(self.backbone.forward_image(input.img_batch))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vl_combiner.py", line 79, in forward_image
return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/act_ckpt_utils.py", line 86, in act_ckpt_wrapper
ret = module(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vl_combiner.py", line 86, in _forward_image_no_act_ckpt
sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/necks.py", line 108, in forward
xs = self.trunk(tensor_list)
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vitdet.py", line 840, in forward
x = checkpoint.checkpoint(blk, x, use_reentrant=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py", line 495, in checkpoint
ret = function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vitdet.py", line 607, in forward
x = self.ls1(self.attn(x))
^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vitdet.py", line 487, in forward
q, k = self._apply_rope(q, k)
^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vitdet.py", line 464, in _apply_rope
return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vitdet.py", line 80, in apply_rotary_enc
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/kaggle/working/sam3/sam3/model/vitdet.py", line 63, in reshape_for_broadcast
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
Which resolution values are safe to use to avoid this assertion error?