
New GRPO doesn't support models besides Llama (Mistral)

Open · advpropsys opened this issue 1 year ago · 15 comments

To reproduce: take any GRPO script and use Mistral (for example) instead of Llama; it fails at the `lm_head` matmul.

Tested in Docker and in a clean venv install. Llama works fine.

```
[rank0]: File "xxx/rl-scripts/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 80, in compute_loss
[rank0]:     new_logits = torch.matmul(new_hidden_states, lm_head.t())

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x7f477665f1c0>(*(GradTrackingTensor(lvl=1, value=....
[rank0]:   FakeTensor(..., device='cuda:0', size=(1, s0, 32000), dtype=torch.bfloat16,
[rank0]:              requires_grad=True)
[rank0]: ), GradTrackingTensor(lvl=1, value=
[rank0]:   FakeTensor(..., device='cuda:0', size=(4096, 32000), dtype=torch.bfloat16)
[rank0]: )), **{}):
[rank0]: a and b must have same reduction dim, but got [s0, 32000] X [4096, 32000].
```

The error arises in `UnslothEfficientGRPO.apply`. I have tried multiple fixes in `rl_replacements` (in `unsloth_zoo`) with no luck; it needs deeper investigation.

Example:

```python
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "mistralai/Mistral-7B-Instruct-v0.2",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)
import re
from datasets import load_dataset, Dataset
COUNTER = 0
PRINT_EVERY = 20

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    global COUNTER
    if COUNTER % PRINT_EVERY == 0:
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    COUNTER += 1
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

max_prompt_length = 256
from trl import GRPOConfig, GRPOTrainer

# Optional extra params for vLLM
from unsloth import vLLMSamplingParams
vllm_sampling_params = vLLMSamplingParams(
    min_p = 0.01,
    seed = 3407,
)
training_args = GRPOConfig(
    learning_rate = 5e-6,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    report_to = "none", # Can use Weights & Biases
    vllm_sampling_params = vllm_sampling_params, # Optional
    temperature = 1.0,
)
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
```

I have tried many combinations of training args and model configs, with no luck.

advpropsys · Feb 21 '25 19:02

I have this error:

```
unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss
[rank0]:     prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
[rank0]:                               ~~~~~~^^^^^^^^^^^^^^
[rank0]: TypeError: list indices must be integers or slices, not str
```

How can I fix this?

xellDart · Feb 22 '25 00:02

@advpropsys

Did you try with Phi-4 or Qwen, and does it still error?

shimmyshimmer · Feb 22 '25 11:02

> I have `unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss` ... `TypeError: list indices must be integers or slices, not str`
>
> How can I fix this?

This is probably an issue with dataset generation. Try loading the JSON in pandas, then converting it to a dataset: `pd.read_json()`, then `load_dataset()` has an attribute to load a pandas dataset.
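For what it's worth, the message itself says that `inputs` was a Python list (for example, a list of per-example dicts) rather than the mapping keyed by column name that `compute_loss` indexes into. The pure-Python sketch below only illustrates that failure mode under that assumption; the row contents and the `collate_rows` helper are hypothetical, not part of Unsloth or TRL, and it does not establish what caused the error in this thread.

```python
# Hypothetical per-example rows (token ids are made up for illustration).
rows = [
    {"prompt_ids": [1, 2, 3], "prompt_mask": [1, 1, 1]},
    {"prompt_ids": [4, 5, 0], "prompt_mask": [1, 1, 0]},
]

# Indexing a list by a column name raises the same TypeError as the trace.
try:
    rows["prompt_ids"]
except TypeError as err:
    print(err)  # list indices must be integers or slices, not str


def collate_rows(rows):
    """Hypothetical helper: turn a list of per-example dicts into a dict of lists."""
    return {key: [row[key] for row in rows] for key in rows[0]}


# Once the batch is a mapping, the same lookup succeeds.
inputs = collate_rows(rows)
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
print(prompt_ids)  # [[1, 2, 3], [4, 5, 0]]
```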

shimmyshimmer · Feb 22 '25 11:02

> @advpropsys
>
> Did you try with Phi-4 or Qwen, and does it still error?

@shimmyshimmer Phi4-bnb-4bit does indeed work! Apparently it's only a Mistral problem.

advpropsys · Feb 22 '25 13:02

> This is probably an issue with dataset generation. Try loading the JSON in pandas, then converting it to a dataset: `pd.read_json()`, then `load_dataset()` has an attribute to load a pandas dataset.

Similar error in unsloth 2025.02.12 for Phi-4; 2025.02.04 seems OK.

Update: sorry, 2025.02.12 is also OK for Phi-4. Simply reinstalling unsloth seems to have fixed my env.

AiHaibara · Feb 24 '25 07:02

> I have `unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss` ... `TypeError: list indices must be integers or slices, not str`
>
> How can I fix this?

I had the same issue, then downgraded to unsloth 2025.2.9 and it works well. FYI, if you use Windows you can install it with `pip install "unsloth[cu124-torch251] @ git+https://github.com/unslothai/unsloth.git@179840d3a7b49188c372b56c67c4290d53c29ed6"`, replacing `cu124-torch251` with your CUDA and torch versions.

KuroKienDinh · Feb 25 '25 03:02

> I had the same issue, then downgraded to unsloth 2025.2.9 and it works well.

Does the latest version of unsloth still not work for Mistral, while downgrading seems to work?

shimmyshimmer · Feb 25 '25 03:02

> Does the latest version of unsloth still not work for Mistral, while downgrading seems to work?

I think so.

KuroKienDinh · Feb 25 '25 04:02

> Does the latest version of unsloth still not work for Mistral, while downgrading seems to work?

The older version works; the new efficient GRPO doesn't.

advpropsys · Feb 25 '25 08:02

The issue is that `MistralForCausalLM_fast_forward` always returns logits instead of hidden states. The fix is on the way: #1831
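A shape-only check makes this explanation concrete. The `compute_loss` line in the first trace multiplies its input by `lm_head.t()`, so that input must have the hidden size (4096 here, per the trace) as its last dimension; logits end in the vocab size (32000) instead. The pure-Python sketch below mimics only the reduction-dimension rule `torch.matmul` applies when the right operand is 2-D; `matmul_shape` is a hypothetical helper and `SEQ = 7` is an arbitrary stand-in for `s0`.

```python
# Shape-only sketch of: new_logits = torch.matmul(new_hidden_states, lm_head.t())
# Sizes from the trace: hidden size 4096, vocab size 32000.


def matmul_shape(a, b):
    """Result shape of a (possibly batched) a @ b where b is 2-D.

    The last dim of `a` (the reduction dim) must equal the first dim of `b`.
    """
    if a[-1] != b[0]:
        raise ValueError(
            f"a and b must have same reduction dim, but got {list(a[-2:])} X {list(b)}"
        )
    return tuple(a[:-1]) + (b[1],)


HIDDEN, VOCAB, SEQ = 4096, 32000, 7  # SEQ stands in for the symbolic s0

lm_head = (VOCAB, HIDDEN)              # output-projection weight shape
lm_head_t = (lm_head[1], lm_head[0])   # lm_head.t() -> (4096, 32000)

# Expected input: final hidden states, before the lm_head projection.
hidden_states = (1, SEQ, HIDDEN)
print(matmul_shape(hidden_states, lm_head_t))  # (1, 7, 32000)

# What the Mistral fast forward returned instead: logits.
logits = (1, SEQ, VOCAB)
try:
    matmul_shape(logits, lm_head_t)
except ValueError as err:
    print(err)  # a and b must have same reduction dim, but got [7, 32000] X [4096, 32000]
```

Feeding logits back through `lm_head` would also be conceptually wrong even if the sizes happened to match, which is why the fix changes what the forward returns rather than the matmul.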

oKatanaaa · Feb 25 '25 18:02

> The issue is that `MistralForCausalLM_fast_forward` always returns logits instead of hidden states. The fix is on the way: #1831

I still have this problem while training Llama-3.2-1B-Instruct at commit 2c0f50160e227936e0011d67e3bc2472c2089629:

```
File "/code/unsloth_20250226/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 766, in compute_loss
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
                              ~~~~~~^^^^^^^^^^^^^^
TypeError: list indices must be integers or slices, not str
```

I tried commit 179840d3a7b49188c372b56c67c4290d53c29ed6 and still get the same error, while commit 512fec6a7b77a930b85a5b5685bf056fbb29ff5e works for me.

Any suggestions?

xudou3 · Feb 26 '25 09:02

> I still have this problem while training Llama-3.2-1B-Instruct at commit 2c0f501: `TypeError: list indices must be integers or slices, not str` ... I tried commit 179840d and still get the same error, while commit 512fec6 works for me. Any suggestions?

That seems to be a completely separate problem. Make an issue for that and attach (a) code to reproduce the problem and (b) the full error trace.

oKatanaaa · Feb 26 '25 09:02

> That seems to be a completely separate problem. Make an issue for that and attach (a) code to reproduce the problem and (b) the full error trace.

Opened: https://github.com/unslothai/unsloth/issues/1836

xudou3 · Feb 26 '25 09:02

This should fix the issue: https://github.com/unslothai/unsloth/issues/1836#issuecomment-2685898012

kings-crown · Feb 26 '25 18:02

@AiHaibara @xellDart @xudou3 @kings-crown Mistral should now work, courtesy of @oKatanaaa :)

For Colab / Kaggle, please restart and run all. For local machines, please run:

`pip install --force-reinstall --upgrade --no-cache-dir --no-deps unsloth unsloth_zoo`
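After reinstalling, it can help to confirm which versions Python actually picks up, since earlier comments compare specific releases (2025.2.9, 2025.02.12). A minimal standard-library sketch, assuming the installed distribution names match the pip package names `unsloth` and `unsloth_zoo`; `installed_version` is a hypothetical helper:

```python
# Report installed versions using only the standard library (Python 3.8+).
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for pkg in ("unsloth", "unsloth_zoo"):
    print(pkg, installed_version(pkg) or "not installed")
```

Running this in the same environment (or notebook kernel) that runs training avoids confusing a fresh install with a stale one.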

danielhanchen · Mar 05 '25 13:03