Llama Model throwing "RuntimeError: expected scalar type BFloat16 but found Float" when using torch.compile and AMP together
System Info
- transformers version: 4.41.0
- torch version: 2.3.0
- GPU: NVIDIA GeForce RTX 4090
- CUDA version: 12.3
Who can help?
No response
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM, AdamW, AutoModelForCausalLM, GPT2Config
from torch.cuda.amp import autocast, GradScaler

# Configure the model
config = LlamaConfig(
    num_attention_heads=6,
    num_hidden_layers=6,
    hidden_size=384,
    intermediate_size=1536,  # Typically 4 * hidden_size
    vocab_size=30522,  # Standard vocabulary size
    max_position_embeddings=1024,
)

# config = GPT2Config(
#     n_embd=384,
#     n_head=6,
#     n_layer=6,
#     n_positions=1024,
#     n_ctx=1024,
#     n_vocab=30522,
# )

# Initialize the model
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager").to('cuda')

# Compile the model (Torch 2.0 and above)
model = torch.compile(model)

# Create dummy data
batch_size = 8
sequence_length = 1024
dummy_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')
dummy_labels = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')

# Set up the optimizer
optimizer = AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()

# Set the model to training mode
model.train()

# Training loop
num_epochs = 10000
for epoch in range(num_epochs):
    with autocast(dtype=torch.bfloat16, enabled=True):
        # Forward pass
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss

    # Backward pass
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

print("Training complete.")
```
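As an isolation sketch (the error only shows up when torch.compile and AMP are combined, per the title), here is the same loop with the torch.compile call left out; everything else is unchanged from the script above, and this variant is expected to run.

```python
# Same setup as above, but without torch.compile; this variant is expected to
# run, which points at the compiled backward graph as the source of the mismatch.
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager").to('cuda')
# model = torch.compile(model)  # intentionally omitted

optimizer = AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()
model.train()

for epoch in range(10):
    with autocast(dtype=torch.bfloat16, enabled=True):
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```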
Expected behavior
I expect the training loop to run, but the code snippet above raises the following error:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 54
51 loss = outputs.loss
53 # Backward pass
---> 54 scaler.scale(loss).backward()
55 scaler.step(optimizer)
56 scaler.update()
File ~/anaconda3/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
515 if has_torch_function_unary(self):
516 return handle_torch_function(
517 Tensor.backward,
518 (self,),
(...)
523 inputs=inputs,
524 )
--> 525 torch.autograd.backward(
526 self, gradient, retain_graph, create_graph, inputs=inputs
527 )
File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
262 retain_graph = create_graph
264 # The reason we repeat the same comment below is that
265 # some Python versions print out the first line of a multi-line function
266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
268 tensors,
269 grad_tensors_,
270 retain_graph,
271 create_graph,
272 inputs,
273 allow_unreachable=True,
274 accumulate_grad=True,
275 )
File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
742 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
743 try:
--> 744 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
745 t_outputs, *args, **kwargs
746 ) # Calls into the C++ engine to run the backward pass
747 finally:
748 if attach_logging_hooks:
File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
295 raise RuntimeError(
296 \"Implementing both 'backward' and 'vjp' for a custom \"
297 \"Function is not allowed. You should only implement one \"
298 \"of them.\"
299 )
300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)
File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:882, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
880 out = CompiledFunctionBackward.apply(*all_args)
881 else:
--> 882 out = call_compiled_backward()
884 # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
885 if CompiledFunction.maybe_subclass_metadata is not None:
File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
824 with tracing(saved_context), context(), track_graph_compiling(
825 aot_config, "backward"
826 ):
827 CompiledFunction.compiled_bw = aot_config.bw_compiler(
828 bw_module, placeholder_list
829 )
--> 831 out = call_func_at_runtime_with_args(
832 CompiledFunction.compiled_bw,
833 all_args,
834 steal_args=True,
835 disable_amp=disable_amp,
836 )
838 out = functionalized_rng_runtime_epilogue(
839 CompiledFunction.metadata, out
840 )
841 return tuple(out)
File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:113, in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
111 with context():
112 if hasattr(f, "_boxed_call"):
--> 113 out = normalize_as_list(f(args))
114 else:
115 # TODO: Please remove soon
116 # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
117 warnings.warn(
118 \"Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. \"
119 \"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \"
120 \"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\"
121 )
File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
449 prior = set_eval_frame(callback)
450 try:
--> 451 return fn(*args, **kwargs)
452 finally:
453 set_eval_frame(prior)
File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs)
34 @functools.wraps(fn)
35 def inner(*args, **kwargs):
---> 36 return fn(*args, **kwargs)
File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:906, in CompiledFxGraph.__call__(self, inputs)
905 def __call__(self, inputs: List[Any]) -> Any:
--> 906 return self.get_current_callable()(inputs)
File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:784, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
782 def run(new_inputs):
783 copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 784 return model(new_inputs)
File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:934, in _run_from_cache(compiled_graph, inputs)
926 assert compiled_graph.artifact_path
927 compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
928 compiled_graph.cache_key,
929 compiled_graph.artifact_path,
930 compiled_graph.cache_linemap,
931 compiled_graph.constants,
932 ).call
--> 934 return compiled_graph.compiled_artifact(inputs)
File /tmp/torchinductor_zcai75/wq/cwqm67koqia7gthn65wgmhppfzrfyheocl4px7fecurpkfigigfs.py:1751, in call(args)
1749 buf39 = reinterpret_tensor(buf34, (48, 64, 1024), (65536, 1024, 1), 0); del buf34 # reuse
1750 # Source Nodes: [], Original ATen: [aten.bmm]
-> 1751 extern_kernels.bmm(permute_103, reinterpret_tensor(buf38, (48, 1024, 1024), (1048576, 1024, 1), 0), out=buf39)
1752 del permute_103
1753 buf41 = empty_strided_cuda((8, 6, 1024, 64), (393216, 65536, 64, 1), torch.bfloat16)
RuntimeError: expected scalar type BFloat16 but found Float
```
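The trace ends inside an Inductor-generated kernel where extern_kernels.bmm writes into a bfloat16 buffer while one operand (permute_103) apparently stayed float32. For illustration only, here is a standalone sketch of that kind of mismatch (shapes taken from the generated code, otherwise hypothetical; this is not the compiled kernel itself):

```python
import torch

# bmm requires all operands (including out=) to share one dtype, so mixing
# bfloat16 and float32 raises a dtype-mismatch RuntimeError of the same
# "expected scalar type ... but found ..." family as the error above.
a = torch.randn(48, 64, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(48, 1024, 1024, device="cuda", dtype=torch.float32)
out = torch.empty(48, 64, 1024, device="cuda", dtype=torch.bfloat16)
torch.bmm(a, b, out=out)  # raises RuntimeError due to mixed dtypes
```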
This problem does not seem to happen with a GPT-2 model: if I initialize GPT2Config instead of LlamaConfig (the commented-out block in the script above), there is no such error.
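In case it helps narrow things down: GradScaler is aimed at float16 loss scaling, and bfloat16 shares float32's exponent range, so the scaler isn't strictly needed in this setup. Below is a sketch of the loop without it; I have not verified whether this sidesteps the compiled-backward dtype mismatch, it is just the plain bf16 path.

```python
# Sketch: bf16 autocast without GradScaler; loss scaling is a float16 concern,
# so the plain backward/step path can be used with bfloat16.
for epoch in range(num_epochs):
    with autocast(dtype=torch.bfloat16, enabled=True):
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
```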