transformers icon indicating copy to clipboard operation
transformers copied to clipboard

Llama Model throwing "RuntimeError: expected scalar type BFloat16 but found Float" when using torch.compile and AMP together

Open JackCai1206 opened this issue 9 months ago • 7 comments

System Info

transformers 4.41.0, torch 2.3.0, GPU: NVIDIA GeForce RTX 4090, CUDA 12.3

Who can help?

No response

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

import torch
from transformers import LlamaConfig, LlamaForCausalLM, AdamW, AutoModelForCausalLM, GPT2Config
from torch.cuda.amp import autocast, GradScaler

# Minimal reproduction: a torch.compile'd Llama model trained under bfloat16
# autocast raises "expected scalar type BFloat16 but found Float" from the
# compiled backward pass (i.e. at `loss.backward()`). Requires a CUDA GPU.

# Configure a small Llama model.
config = LlamaConfig(
    num_attention_heads=6,
    num_hidden_layers=6,
    hidden_size=384,
    intermediate_size=1536,  # 4 * hidden_size
    vocab_size=30522,
    max_position_embeddings=1024,
)

# Same-size GPT-2 alternative for comparison (reported not to hit the error).
# GPT2Config takes `vocab_size`; `n_vocab` / `n_ctx` are not GPT2Config
# arguments and would be silently ignored, leaving the default vocabulary.
# config = GPT2Config(
#     n_embd=384,
#     n_head=6,
#     n_layer=6,
#     n_positions=1024,
#     vocab_size=30522,
# )

device = "cuda"

# Initialize the model with the eager attention implementation, then compile it.
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager").to(device)
model = torch.compile(model)

# Dummy data, allocated directly on the GPU rather than created on CPU and copied.
batch_size = 8
sequence_length = 1024
dummy_input_ids = torch.randint(
    0, config.vocab_size, (batch_size, sequence_length), device=device
)
dummy_labels = torch.randint(
    0, config.vocab_size, (batch_size, sequence_length), device=device
)

# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
# weight_decay are set explicitly to match transformers.AdamW's defaults.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)

model.train()

num_epochs = 10000  # each "epoch" is one step on the same dummy batch
for epoch in range(num_epochs):
    # torch.autocast(device_type=...) is the non-deprecated spelling of
    # torch.cuda.amp.autocast(...).
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss

    # GradScaler exists to avoid float16 gradient underflow; bfloat16 keeps
    # float32's exponent range, so a plain backward/step is sufficient.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

print("Training complete.")

Expected behavior

Expected: the training loop runs without errors. Actual: running the code snippet above fails in the backward pass with the following error:

{
	"name": "RuntimeError",
	"message": "expected scalar type BFloat16 but found Float",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 54
     51     loss = outputs.loss
     53 # Backward pass
---> 54 scaler.scale(loss).backward()
     55 scaler.step(optimizer)
     56 scaler.update()

File ~/anaconda3/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
    295     raise RuntimeError(
    296         \"Implementing both 'backward' and 'vjp' for a custom \"
    297         \"Function is not allowed. You should only implement one \"
    298         \"of them.\"
    299     )
    300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)

File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:882, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
    880     out = CompiledFunctionBackward.apply(*all_args)
    881 else:
--> 882     out = call_compiled_backward()
    884 # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
    885 if CompiledFunction.maybe_subclass_metadata is not None:

File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
    824     with tracing(saved_context), context(), track_graph_compiling(
    825         aot_config, \"backward\"
    826     ):
    827         CompiledFunction.compiled_bw = aot_config.bw_compiler(
    828             bw_module, placeholder_list
    829         )
--> 831 out = call_func_at_runtime_with_args(
    832     CompiledFunction.compiled_bw,
    833     all_args,
    834     steal_args=True,
    835     disable_amp=disable_amp,
    836 )
    838 out = functionalized_rng_runtime_epilogue(
    839     CompiledFunction.metadata, out
    840 )
    841 return tuple(out)

File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:113, in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
    111 with context():
    112     if hasattr(f, \"_boxed_call\"):
--> 113         out = normalize_as_list(f(args))
    114     else:
    115         # TODO: Please remove soon
    116         # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
    117         warnings.warn(
    118             \"Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. \"
    119             \"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \"
    120             \"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\"
    121         )

File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    449 prior = set_eval_frame(callback)
    450 try:
--> 451     return fn(*args, **kwargs)
    452 finally:
    453     set_eval_frame(prior)

File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs)
     34 @functools.wraps(fn)
     35 def inner(*args, **kwargs):
---> 36     return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:906, in CompiledFxGraph.__call__(self, inputs)
    905 def __call__(self, inputs: List[Any]) -> Any:
--> 906     return self.get_current_callable()(inputs)

File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:784, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
    782 def run(new_inputs):
    783     copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 784     return model(new_inputs)

File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:934, in _run_from_cache(compiled_graph, inputs)
    926     assert compiled_graph.artifact_path
    927     compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
    928         compiled_graph.cache_key,
    929         compiled_graph.artifact_path,
    930         compiled_graph.cache_linemap,
    931         compiled_graph.constants,
    932     ).call
--> 934 return compiled_graph.compiled_artifact(inputs)

File /tmp/torchinductor_zcai75/wq/cwqm67koqia7gthn65wgmhppfzrfyheocl4px7fecurpkfigigfs.py:1751, in call(args)
   1749 buf39 = reinterpret_tensor(buf34, (48, 64, 1024), (65536, 1024, 1), 0); del buf34  # reuse
   1750 # Source Nodes: [], Original ATen: [aten.bmm]
-> 1751 extern_kernels.bmm(permute_103, reinterpret_tensor(buf38, (48, 1024, 1024), (1048576, 1024, 1), 0), out=buf39)
   1752 del permute_103
   1753 buf41 = empty_strided_cuda((8, 6, 1024, 64), (393216, 65536, 64, 1), torch.bfloat16)

RuntimeError: expected scalar type BFloat16 but found Float"
}

This problem does not seem to occur with a GPT-2 model: if I use the commented-out GPT2Config in the script instead of the LlamaConfig, the error does not appear.

JackCai1206 avatar May 21 '24 18:05 JackCai1206