System Info

transformers 4.41.0, torch 2.3.0, GPU: NVIDIA GeForce RTX 4090, CUDA version 12.3

import torch
from transformers import LlamaConfig, LlamaForCausalLM, AdamW, AutoModelForCausalLM, GPT2Config
from torch.cuda.amp import autocast, GradScaler

# Minimal reproduction: train a small Llama model on random token ids under
# bfloat16 autocast, with the model wrapped by torch.compile.

# Configure the model. The size arguments mirror the commented-out GPT-2
# config below (384 hidden, 6 heads, 6 layers, 1024 positions) so the two
# models are directly comparable.
config = LlamaConfig(
    hidden_size=384,
    num_attention_heads=6,
    num_hidden_layers=6,
    max_position_embeddings=1024,
    intermediate_size=1536,  # Typically 4 * hidden_size
    vocab_size=30522,        # Standard vocabulary size
)

# config = GPT2Config(
#     n_embd=384,
#     n_head=6,
#     n_layer=6,
#     n_positions=1024,
#     n_ctx=1024,
#     vocab_size=30522,  # GPT2Config takes `vocab_size`; `n_vocab` is not one of its arguments
# )

# Initialize the model
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager").to('cuda')

# Compile the model (Torch 2.0 and above)
model = torch.compile(model)

# Create dummy data
batch_size = 8
sequence_length = 1024
dummy_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')
dummy_labels = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')

# Set up the optimizer. transformers.AdamW is deprecated; torch.optim.AdamW
# is the recommended replacement.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Loss scaling is designed for float16. bfloat16 has float32's exponent range
# and normally does not need it; the scaler is kept so the script matches the
# reported traceback.
scaler = GradScaler()

# Set the model to training mode
model.train()

# Training loop (each "epoch" is a single step on the same dummy batch)
num_epochs = 10000
for epoch in range(num_epochs):
    # Clear gradients from the previous step; otherwise they accumulate.
    optimizer.zero_grad(set_to_none=True)

    with autocast(dtype=torch.bfloat16, enabled=True):
        # Forward pass
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss

    # Backward pass (outside the autocast region, as the AMP docs recommend)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

print("Training complete.")

Expected behavior

The training loop should run without errors. Instead, running the code snippet above gives me the following error:

	"name": "RuntimeError",
	"message": "expected scalar type BFloat16 but found Float",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 54
     51     loss = outputs.loss
     53 # Backward pass
---> 54 scaler.scale(loss).backward()
     55 scaler.step(optimizer)
     56 scaler.update()

RuntimeError: expected scalar type BFloat16 but found Float"

This problem does not seem to happen with a GPT-2 model: if I use the commented-out GPT2Config in the script instead of the LlamaConfig, the error does not occur.

Posted by JackCai1206 on May 21 '24 at 18:05