open_clip
Error when using torchcompile option for CLIP training
Hello,
When I try to use the --torchcompile option to train a CLIP ViT-B-32 model, I get an error. Below is the command I use to run training.
torchrun --nproc_per_node 16 -m training.main --save-frequency 1 --zeroshot-frequency 1 --report-to tensorboard --train-data={data_dir} --csv-img-key filepath --csv-caption-key title --imagenet-val={imagenet val dir} --workers=8 --model ViT-B-32 --precision amp_bf16 --workers 4 --csv-separator "," --local-loss --gather-with-grad --aug-cfg scale='(0.5, 1.0)' --name test --accum-freq 4 --grad-checkpointing --torchcompile
I get the error message below.
How can I fix this issue?
Note that my PyTorch version is 2.1.0, and no error occurs when I run the above command without the --torchcompile option.
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/workspace/open_clip/src/training/main.py", line 508, in <module>
main(sys.argv[1:])
File "/workspace/open_clip/src/training/main.py", line 436, in main
train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
File "/workspace/open_clip/src/training/train.py", line 117, in train_one_epoch
model_out = model(images, texts)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 487, in catch_errors
return hijacked_callback(frame, cache_entry, hooks, frame_state)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 586, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1167, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
return tx.inline_user_function_return(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
return tx.inline_user_function_return(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1123, in call_function
p_args, _, example_value = self.create_wrapped_node(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1025, in create_wrapped_node
) = speculate_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 203, in speculate_subgraph
f"to trace function `{f.get_name()}` into a single graph. This means "
torch._dynamo.exc.InternalTorchDynamoError: 'NNModuleVariable' object has no attribute 'get_name'
from user code:
File "/workspace/open_clip/src/open_clip/model.py", line 293, in forward
image_features = self.encode_image(image, normalize=True) if image is not None else None
File "/workspace/open_clip/src/open_clip/model.py", line 266, in encode_image
features = self.visual(image)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/open_clip/src/open_clip/transformer.py", line 516, in forward
x = self.transformer(x)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/open_clip/src/open_clip/transformer.py", line 322, in forward
x = checkpoint(r, x, None, None, attn_mask)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
@kkjh0723 I think it might be breaking because of gradient checkpointing. I'm not sure there is a workaround; possibly using the non-reentrant mode?
I got the same error when running with both --grad-checkpointing and --torchcompile, but since PyTorch 2.1.0, --torchcompile works with --accum-freq > 1, which is the next best option.
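For a rough sense of what --accum-freq buys, here is a minimal sketch assuming the effective batch per optimizer step is the per-GPU batch times the number of processes times the accumulation frequency. The 16 processes and --accum-freq 4 come from the command at the top of the issue; the per-GPU batch of 64 is a hypothetical value, since that command does not set --batch-size.

```python
def effective_batch_size(per_gpu_batch: int, world_size: int, accum_freq: int = 1) -> int:
    """Samples per optimizer step, assuming linear scaling with accumulation."""
    if min(per_gpu_batch, world_size, accum_freq) < 1:
        raise ValueError("all arguments must be positive integers")
    return per_gpu_batch * world_size * accum_freq


# world_size=16 and accum_freq=4 are from the command in the issue;
# per_gpu_batch=64 is a hypothetical value.
print(effective_batch_size(64, 16))     # 1024 without accumulation
print(effective_batch_size(64, 16, 4))  # 4096 with --accum-freq 4
```

Under this assumption each forward/backward still runs on one 64-sample micro-batch per GPU, so activation memory stays roughly at the micro-batch level, which is why accumulation can stand in for checkpointing when the goal is a large effective batch.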
@EIFY did you try forcing non-reentrant checkpointing? We could look at changing the default if that works...
@rwightman No, I haven't tried that.
In that regard, the good news is that https://github.com/pytorch/pytorch/issues/79887, referenced at https://github.com/mlfoundations/open_clip/blob/91923dfc376afb9d44577a0c9bd0930389349438/src/open_clip/transformer.py#L320-L322, is now fixed, so we should be able to do, e.g.:
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
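For reference, a minimal eager-mode sketch of what use_reentrant=False does, with a hypothetical ToyBlock standing in for the residual block r (its forward accepts the same four positional arguments as the call above but ignores the last three). It checks that checkpointing changes only what is stored between forward and backward, not the outputs or gradients. It does not involve torch.compile, so it neither reproduces nor fixes the Dynamo error.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a residual block; not open_clip's class."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, q_x, k_x=None, v_x=None, attn_mask=None):
        # k_x, v_x and attn_mask are accepted only to match the call shape.
        return q_x + torch.relu(self.fc(q_x))


block = ToyBlock()
x = torch.randn(4, 8, requires_grad=True)

# Reference: ordinary forward/backward keeps all intermediate activations.
out_ref = block(x, None, None, None)
out_ref.sum().backward()
grad_ref = x.grad.clone()
x.grad = None

# Non-reentrant checkpointing: intermediates inside `block` are not kept
# after the forward pass and are recomputed during backward.
out_ckpt = checkpoint(block, x, None, None, None, use_reentrant=False)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(out_ref, out_ckpt), torch.allclose(grad_ref, grad_ckpt))  # True True
```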
The bad news is that, apart from that spot, grad_checkpointing is either delegated to the vision/text trunks without argument support:
https://github.com/mlfoundations/open_clip/blob/91923dfc376afb9d44577a0c9bd0930389349438/src/open_clip/model.py#L260-L263
or not supported at all:
https://github.com/mlfoundations/open_clip/blob/91923dfc376afb9d44577a0c9bd0930389349438/src/open_clip/modified_resnet.py#L161-L164
So fairly involved changes would be necessary. When I get a chance, I will try the easy part and see whether it at least gets past this error.
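The missing argument plumbing could look roughly like the following. This is a hypothetical, torch-free sketch of the delegation pattern only: Tower and TwoTowerModel are illustrative stand-ins, not open_clip classes, and only the method name set_grad_checkpointing and the use_reentrant keyword come from this thread.

```python
from typing import Any, Dict


class Tower:
    """Hypothetical stand-in for a vision or text trunk (not open_clip's class)."""

    def __init__(self) -> None:
        self.grad_checkpointing = False
        self.checkpoint_kwargs: Dict[str, Any] = {}

    def set_grad_checkpointing(self, enable: bool = True, **checkpoint_kwargs: Any) -> None:
        self.grad_checkpointing = enable
        # Stored so a real forward could call
        # checkpoint(fn, *args, **self.checkpoint_kwargs).
        self.checkpoint_kwargs = dict(checkpoint_kwargs)


class TwoTowerModel:
    """Hypothetical top-level model that delegates to both trunks."""

    def __init__(self) -> None:
        self.visual = Tower()
        self.text = Tower()

    def set_grad_checkpointing(self, enable: bool = True, **checkpoint_kwargs: Any) -> None:
        # Forward the keyword arguments to every trunk instead of
        # dropping them at the top level.
        for tower in (self.visual, self.text):
            tower.set_grad_checkpointing(enable, **checkpoint_kwargs)


model = TwoTowerModel()
model.set_grad_checkpointing(True, use_reentrant=False)
print(model.visual.checkpoint_kwargs)  # {'use_reentrant': False}
```

A real trunk would then pass **self.checkpoint_kwargs through to its checkpoint(...) call in forward; trunks with no checkpointing support at all would still need their own implementation.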
@rwightman OK, so it turns out that use_reentrant=False doesn't help. It still breaks at the same point:
[2023-11-08 12:56:29,383] [0/0] torch._utils_internal: [INFO] CompilationMetrics(frame_key='1', co_name='forward', co_filename='/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py', co_firstlineno=256, cache_size=0, guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, fail_reason="'NNModuleVariable' object has no attribute 'get_name'")
Traceback (most recent call last):
(...)
torch._dynamo.exc.InternalTorchDynamoError: 'NNModuleVariable' object has no attribute 'get_name'
from user code:
File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py", line 274, in forward
image_features = dim_scale_img * self.encode_image(image, normalize=self.normalize) if image is not None else None
File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py", line 239, in encode_image
features = self.visual(image)
File "/home/jason-chou/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/transformer.py", line 486, in forward
x = self.transformer(x)
File "/home/jason-chou/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/transformer.py", line 319, in forward
x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
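Until the combination works, one stopgap is to reject it when arguments are parsed rather than failing deep inside Dynamo. The sketch below is a hypothetical standalone argparse parser that only mirrors the three flag names used in this thread; it is not open_clip's own argument parsing, and the error text simply restates what was reported here for PyTorch 2.1.0.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Hypothetical parser; only the flag names are taken from this thread.
    parser = argparse.ArgumentParser(description="compile/checkpoint flag guard sketch")
    parser.add_argument("--torchcompile", action="store_true")
    parser.add_argument("--grad-checkpointing", action="store_true")
    parser.add_argument("--accum-freq", type=int, default=1)
    args = parser.parse_args(argv)

    # parser.error() prints the message to stderr and exits with status 2.
    if args.torchcompile and args.grad_checkpointing:
        parser.error(
            "--torchcompile with --grad-checkpointing fails on PyTorch 2.1.0 "
            "(InternalTorchDynamoError); drop --grad-checkpointing and "
            "consider --accum-freq > 1 instead"
        )
    if args.accum_freq < 1:
        parser.error("--accum-freq must be >= 1")
    return args
```

For example, parse_args(["--torchcompile", "--accum-freq", "4"]) returns normally, while parse_args(["--torchcompile", "--grad-checkpointing"]) exits with status 2 before any model is built.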
Is there any update on this? I am facing the same issue.