torchtitan
FSDP + SP does not work with --compile
FSDP + SP works fine when compile is off, but I get the following error when compile is on:
error log
SP=2 ./run_llama_train.sh
+ TRAINER_DIR=/home/lty/local/torchtrain
+ MODEL=llama
+ MODEL_CONF=debugmodel
+ NGPU=8
+ PP=1
+ SP=2
+ DP=-1
+ LOG_RANK=0
+ CHECKPOINT_FOLDER=
+ CHECKPOINT_INTERVAL=5
+ torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --steps 10 --model llama --model_conf debugmodel --pp_degree 1 --sp_degree 2 --dp_degree -1 --compile --checkpoint-folder= --checkpoint-interval=5
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 *****************************************
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 *****************************************
[rank0]:2024-02-15 17:38:20,132 - torchtrain.parallelisms - INFO - Building 2-D device mesh with ('dp', 'sp'), [4, 2]
[rank0]:2024-02-15 17:38:28,308 - root - INFO - Building llama
[rank0]:2024-02-15 17:38:28,325 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model
[rank0]:2024-02-15 17:38:28,325 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model fully initialized via reset_params
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model built with: ModelArgs(dim=256, n_layers=2, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768)
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model llama debugmodel size: 18,089,216 total parameters
[rank0]:2024-02-15 17:38:31,663 - root - INFO - GPU memory usage: NVIDIA PG509-210 (0): 79.1537 GB capacity, 0.0 GB in-use, 0.0% in-use
[rank0]:NCCL version 2.19.3+cuda12.0
[rank0]:2024-02-15 17:38:36,274 - root - INFO - Applied Sequence Parallelism to the model...
[rank0]:2024-02-15 17:38:36,575 - root - INFO - Applied FSDP to the model...
[rank0]:2024-02-15 17:38:36,579 - root - INFO - Gradient scaling not enabled.
[rank0]:2024-02-15 17:38:36,579 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./torchtrain/outputs/tb/20240215-1738.
[rank0]:2024-02-15 17:38:36,580 - root - INFO - Compiling model llama with torch.compile...
[rank0]:2024-02-15 17:38:40,957 - root - INFO - Profiling active. Traces will be saved at ./torchtrain/outputs/profiling/traces
[rank0]:[rank0]:W0215 17:38:41.362000 139938524181632 torch/_logging/_internal.py:873 [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:/home/lty/pytorch/torch/_inductor/lowering.py:1704: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "/home/lty/torchtrain/train.py", line 349, in <module>
[rank0]:[rank0]: main(args)
[rank0]:[rank0]: File "/home/lty/torchtrain/train.py", line 179, in main
[rank0]:[rank0]: pred = model(input_ids)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/eval_frame.py", line 455, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 482, in forward
[rank0]:[rank0]: def forward(self, tokens: torch.Tensor):
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 498, in torch_dynamo_resume_in_forward_at_493
[rank0]:[rank0]: h = layer(h, freqs_cis)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 912, in catch_errors
[rank0]:[rank0]: return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 777, in _convert_frame
[rank0]:[rank0]: result = inner_convert(
[rank0]:[rank0]: ^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
[rank0]:[rank0]: return _compile(
[rank0]:[rank0]: ^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/.conda/envs/pytorch-3.11/lib/python3.11/contextlib.py", line 81, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 669, in _compile
[rank0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 542, in compile_inner
[rank0]:[rank0]: out_code = transform_code_object(code, transform)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
[rank0]:[rank0]: transformations(instructions, code_options)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 163, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 507, in transform
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2130, in run
[rank0]:[rank0]: super().run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1243, in CALL_FUNCTION_EX
[rank0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 734, in call_function
[rank0]:[rank0]: return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1392, in call_function
[rank0]:[rank0]: ) = self.create_wrapped_node(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1204, in create_wrapped_node
[rank0]:[rank0]: ) = speculate_subgraph(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 396, in speculate_subgraph
[rank0]:[rank0]: output = f.call_function(tx, args, sub_kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/nn_module.py", line 716, in call_function
[rank0]:[rank0]: return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1243, in CALL_FUNCTION_EX
[rank0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/nn_module.py", line 716, in call_function
[rank0]:[rank0]: return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
[rank0]:[rank0]: return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/misc.py", line 547, in call_function
[rank0]:[rank0]: return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 388, in call_method
[rank0]:[rank0]: result = handler_method(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 730, in method_redistribute
[rank0]:[rank0]: return wrap_fx_proxy(
[rank0]:[rank0]: ^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/builder.py", line 1273, in wrap_fx_proxy
[rank0]:[rank0]: return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/builder.py", line 1358, in wrap_fx_proxy_cls
[rank0]:[rank0]: example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1683, in get_fake_value
[rank0]:[rank0]: raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1629, in get_fake_value
[rank0]:[rank0]: ret_val = wrap_fake_exception(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1165, in wrap_fake_exception
[rank0]:[rank0]: return fn()
[rank0]:[rank0]: ^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1630, in <lambda>
[rank0]:[rank0]: lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1750, in run_node
[rank0]:[rank0]: raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1729, in run_node
[rank0]:[rank0]: return node.target(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 723, in redistribute_fn_with_prim_types
[rank0]:[rank0]: return x.redistribute(*args_as_value, **kwargs_as_value)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/api.py", line 467, in redistribute
[rank0]:[rank0]: return Redistribute.apply(self, device_mesh, placements)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/autograd/function.py", line 572, in apply
[rank0]:[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/redistribute.py", line 263, in forward
[rank0]:[rank0]: output = redistribute_local_tensor(local_tensor, current_spec, target_spec)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/redistribute.py", line 164, in redistribute_local_tensor
[rank0]:[rank0]: transform_infos = _gen_transform_infos(current_spec, target_spec)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/placement_types.py", line 441, in __hash__
[rank0]:[rank0]: self._hash = self._hash_impl()
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/placement_types.py", line 424, in _hash_impl
[rank0]:[rank0]: return hash(
[rank0]:[rank0]: ^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/__init__.py", line 309, in __hash__
[rank0]:[rank0]: raise TypeError("unhashable type: non-singleton SymInt")
[rank0]:[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function .redistribute_fn_with_prim_types at 0x7f45431c1b20>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('sp',)), placements=(Shard(dim=0),)),), **{}):
[rank0]:[rank0]: unhashable type: non-singleton SymInt
[rank0]:
[rank0]:[rank0]: from user code:
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
[rank0]:[rank0]: return self.checkpoint_fn( # type: ignore[misc]
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 413, in forward
[rank0]:[rank0]: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1568, in _call_impl
[rank0]:[rank0]: args_result = hook(self, args)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 323, in <lambda>
[rank0]:[rank0]: module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 316, in _prepare_input_fn
[rank0]:[rank0]: dt_inp = dt_inp.redistribute(placements=(desired_layout,))
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]: import torch._dynamo
[rank0]:[rank0]: torch._dynamo.config.suppress_errors = True
[rank0]:
W0215 17:39:06.601000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321633 closing signal SIGTERM
W0215 17:39:06.602000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321634 closing signal SIGTERM
W0215 17:39:06.603000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321636 closing signal SIGTERM
W0215 17:39:06.604000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321637 closing signal SIGTERM
W0215 17:39:06.605000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321638 closing signal SIGTERM
W0215 17:39:06.606000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321639 closing signal SIGTERM
W0215 17:39:06.608000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321641 closing signal SIGTERM
E0215 17:39:09.856000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:669 failed (exitcode: 1) local_rank: 0 (pid: 2321629) of binary: /home/lty/.conda/envs/pytorch-3.11/bin/python
Traceback (most recent call last):
File "/home/lty/.conda/envs/pytorch-3.11/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/run.py", line 834, in main
run(args)
File "/home/lty/pytorch/torch/distributed/run.py", line 825, in run
elastic_launch(
File "/home/lty/pytorch/torch/distributed/launcher/api.py", line 137, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/launcher/api.py", line 271, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-02-15_17:39:06
host : devgpu051.cln3.facebook.com
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 2321629)
error_file:
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
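For reference, the launch above sets NGPU=8, PP=1, SP=2 and DP=-1, and rank 0 reports "Building 2-D device mesh with ('dp', 'sp'), [4, 2]". A minimal sketch of that arithmetic, assuming DP=-1 means "use every GPU not taken by PP and SP" and that degree-1 dimensions are left out of the mesh. resolve_mesh is a hypothetical helper for illustration, not torchtrain's actual code:

```python
def resolve_mesh(ngpu: int, pp: int, sp: int, dp: int) -> tuple[list[str], list[int]]:
    """Hypothetical helper: turn parallelism degrees into mesh dim names and sizes."""
    if dp == -1:
        # Assumption: -1 means "infer DP from the GPUs left over after PP and SP".
        if ngpu % (pp * sp) != 0:
            raise ValueError(f"NGPU={ngpu} is not divisible by PP*SP={pp * sp}")
        dp = ngpu // (pp * sp)
    if dp * pp * sp != ngpu:
        raise ValueError(f"DP*PP*SP={dp * pp * sp} does not match NGPU={ngpu}")
    # Assumption: only dimensions with degree > 1 become mesh dimensions,
    # which is why PP=1 does not show up in the logged mesh.
    dims = [(name, deg) for name, deg in (("pp", pp), ("dp", dp), ("sp", sp)) if deg > 1]
    return [name for name, _ in dims], [deg for _, deg in dims]


names, shape = resolve_mesh(ngpu=8, pp=1, sp=2, dp=-1)
print(names, shape)  # ['dp', 'sp'] [4, 2]
```

Under the same assumption, SP=4 on 8 GPUs would give a [2, 4] ('dp', 'sp') mesh.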
Yeah, this is still not working at the moment because we compile outside of the FSDP wrapping, and I think that triggers a few issues:
- This issue is specifically about dynamic shapes: after a graph break, each subgraph gets traced with dynamic shapes, which is not ideal.
- If I set dynamic=False, I hit a new issue where Dynamo incorrectly traces into DTensor somewhere.
cc @bdhirsh: we probably need to study the dynamic shape issue if we want 2D parallelism to work with torch.compile as the "default" setting.
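To make the dynamic-shapes failure concrete: the bottom of the first traceback shows DTensor's redistribute hashing a spec (placement_types.py __hash__ → _hash_impl → hash(...)) while the local tensor has the symbolic size s0, and hashing that SymInt raises "unhashable type: non-singleton SymInt". The toy model below reproduces only that mechanism in plain Python. ToySymInt, ToySpec and gen_transform_infos are hypothetical stand-ins, not PyTorch classes, and the sketch assumes (as the direct jump from _gen_transform_infos to __hash__ in the traceback suggests) that the transform lookup is memoized, so its arguments are hashed on every call:

```python
from dataclasses import dataclass
from functools import lru_cache


class ToySymInt:
    """Hypothetical stand-in for a symbolic size such as s0 in FakeTensor(size=(s0, 256))."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __hash__(self) -> int:
        # Mirrors the message in the traceback: the symbolic size has no
        # concrete value, so hashing it fails.
        raise TypeError("unhashable type: non-singleton SymInt")


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical spec: frozen dataclass hash covers placements AND shape."""
    placements: tuple
    shape: tuple


@lru_cache(maxsize=None)
def gen_transform_infos(src: ToySpec, dst: ToySpec) -> str:
    # lru_cache hashes (src, dst) before calling, so an unhashable shape
    # fails here, before any redistribution logic would run.
    return f"{src.placements} -> {dst.placements}"


# Concrete shape: hashing works and the cached lookup succeeds.
static = ToySpec(("Shard(0)",), (8, 256))
print(gen_transform_infos(static, ToySpec(("Replicate()",), (8, 256))))

# Symbolic first dimension: hashing the spec raises, as in the log.
dynamic = ToySpec(("Shard(0)",), (ToySymInt("s0"), 256))
try:
    gen_transform_infos(dynamic, ToySpec(("Replicate()",), (ToySymInt("s0"), 256)))
except TypeError as exc:
    print("TypeError:", exc)
```

In the toy, replacing ToySymInt with a plain int makes the lookup succeed, which is consistent with dynamic=False getting past this error and onto the separate DTensor tracing issue.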
@bdhirsh, steps to repro:
- After setting up the repo on a devgpu with 8 GPUs, change the SP degree to 2 or 4: https://github.com/pytorch-labs/torchtrain/blob/main/run_llama_train.sh#L15
- ./run_llama_train.sh
You should be able to hit the issues.
Re-pinging here; I am seeing:
[rank0]:WARNING: Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:NCCL version 2.20.5+cuda12.4
[rank0]:/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_inductor/lowering.py:1789: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "/home/drisspg/meta/torchtrain/train.py", line 361, in <module>
[rank0]:[rank0]: main(config)
[rank0]:[rank0]: File "/home/drisspg/meta/torchtrain/train.py", line 247, in main
[rank0]:[rank0]: pred = model(input_ids)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 390, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 857, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/meta/torchtrain/torchtrain/models/llama/model.py", line 504, in forward
[rank0]:[rank0]: h, freqs_cis = self.embeddings(tokens)
[rank0]:[rank0]: File "/home/drisspg/meta/torchtrain/torchtrain/models/llama/model.py", line 515, in torch_dynamo_resume_in_forward_at_504
[rank0]:[rank0]: h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 279, in __torch_dispatch__
[rank0]:[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 229, in dispatch
[rank0]:[rank0]: return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 368, in wrap
[rank0]:[rank0]: return dtensor.DTensor(
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 229, in __new__
[rank0]:[rank0]: r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 939, in catch_errors
[rank0]:[rank0]: return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
[rank0]:[rank0]: result = inner_convert(
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
[rank0]:[rank0]: return _compile(
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 713, in _compile
[rank0]:[rank0]: raise InternalTorchDynamoError(str(e)).with_traceback(
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile
[rank0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 264, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 541, in compile_inner
[rank0]:[rank0]: out_code = transform_code_object(code, transform)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
[rank0]:[rank0]: transformations(instructions, code_options)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 503, in transform
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2152, in run
[rank0]:[rank0]: super().run()
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 850, in run
[rank0]:[rank0]: while self.step():
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 764, in step
[rank0]:[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 919, in STORE_FAST
[rank0]:[rank0]: loaded_vt.set_name_hint(name)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 91, in realize_and_forward
[rank0]:[rank0]: return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 57, in realize
[rank0]:[rank0]: self._cache.realize()
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 24, in realize
[rank0]:[rank0]: self.vt = VariableBuilder(tx, self.source)(self.value)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 274, in __call__
[rank0]:[rank0]: vt = self._wrap(value)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 424, in _wrap
[rank0]:[rank0]: return self.wrap_tensor(value)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1047, in wrap_tensor
[rank0]:[rank0]: self.assert_not_wrapped_by_this_graph(value)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 978, in assert_not_wrapped_by_this_graph
[rank0]:[rank0]: if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 123, in is_fake
[rank0]:[rank0]: attrs, _ = type(x).__tensor_flatten__(x)
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 256, in __tensor_flatten__
[rank0]:[rank0]: return ["_local_tensor"], (self._spec, self.requires_grad)
[rank0]:[rank0]: torch._dynamo.exc.InternalTorchDynamoError: 'DTensor' object has no attribute '_spec'
[rank0]:
[rank0]:[rank0]: from user code:
[rank0]:[rank0]: File "/home/drisspg/miniconda3/envs/torchtrain/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 229, in torch_dynamo_resume_in___new___at_229
[rank0]:[rank0]: r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]: import torch._dynamo
[rank0]:[rank0]: torch._dynamo.config.suppress_errors = True
cc @bdhirsh
The stack at https://github.com/pytorch/pytorch/pull/123347 looks like it's finally enough to get the torchtrain repro working, with these changes:
diff --git a/train.py b/train.py
index 849ae78..171842a 100644
--- a/train.py
+++ b/train.py
@@ -221,7 +221,7 @@ def main(job_config: JobConfig):
True
)
logger.info("Compiling model with torch.compile")
- model = torch.compile(model)
+ model = torch.compile(model, backend='inductor', dynamic=False)
train_state = TrainState()
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index c84407c..88192fc 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -37,10 +37,10 @@ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
pipeline_parallel_degree = 1
fp8_linear = ""
-compile = false
+compile = true
dataset = "alpaca" # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
[activation_checkpoint]
I had to turn off dynamic shapes. I spent some time fixing a few dynamic-shape issues with DTensor in this PR: https://github.com/pytorch/pytorch/pull/123349, but there are more. So maybe for now we can run all of our torch.compile testing (e.g. with Float8, cc @wanchaol @drisspg @vkuzo) with dynamic=False, and kick the can some more on dynamic shapes (hopefully I'll have more time to keep looking at this).
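To make the dynamic=False trade-off concrete: Dynamo then specializes on exact input shapes, so each new shape costs a fresh static-shape compile instead of one symbolic-shape graph. The toy sketch below counts compiles with a hypothetical counting backend (count_compiles and toy_fn are illustrative names, not torchtrain code); a custom backend simply receives each FX graph and returns a callable.

```python
import torch
import torch._dynamo


def count_compiles(dynamic, batch_sizes, feature_dim=8):
    """Compile a toy function and count how many graphs Dynamo hands to the backend."""
    compiled_graphs = []

    def counting_backend(gm: torch.fx.GraphModule, example_inputs):
        # Record each graph Dynamo produces, then run it as-is (no codegen).
        compiled_graphs.append(gm)
        return gm.forward

    def toy_fn(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2

    torch._dynamo.reset()  # clear caches so earlier runs don't affect the count
    compiled = torch.compile(toy_fn, backend=counting_backend, dynamic=dynamic)
    for bsz in batch_sizes:
        compiled(torch.randn(bsz, feature_dim))
    return len(compiled_graphs)


if __name__ == "__main__":
    # With dynamic=False, every distinct batch size triggers its own compile.
    print(count_compiles(dynamic=False, batch_sizes=[2, 3, 5]))
```

For a training loop whose input shapes stay fixed from step to step, this typically means a single compile per graph, which is why dynamic=False is a workable stopgap.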
Closing, since #268 landed: we are now using per-TransformerBlock compilation.
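One plausible shape for per-TransformerBlock compilation is sketched below. It assumes the blocks live in an nn.ModuleList named layers; ToyBlock, ToyModel and compile_per_block are hypothetical, and the sketch omits FSDP and tensor/sequence parallelism entirely, so it is not the code from #268.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a TransformerBlock (residual MLP only)."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))


class ToyModel(nn.Module):
    def __init__(self, n_layers: int, dim: int):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The top-level forward stays in eager mode; only the blocks are compiled.
        for layer in self.layers:
            x = layer(x)
        return x


def compile_per_block(model: ToyModel, backend: str = "inductor") -> ToyModel:
    """Wrap each block in its own torch.compile instead of compiling the whole model."""
    for idx in range(len(model.layers)):
        model.layers[idx] = torch.compile(model.layers[idx], backend=backend, dynamic=False)
    return model
```

A design consequence of this layout is that anything outside the blocks (embeddings, the top-level forward, wrapper logic) runs eagerly rather than being traced, so Dynamo only ever sees one block at a time.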