New GRPO doesn't support models besides Llama (e.g. Mistral)
To reproduce: take any GRPO script with Mistral (for example) instead of Llama, and it will fail on head matmul
Tested in Docker and/or on a clean venv install. Llama works just fine.
[rank0]: File "xxx/rl-scripts/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 80, in compute_loss [rank0]: new_logits = torch.matmul(new_hidden_states, lm_head.t())
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x7f477665f1c0>(*(GradTrackingTensor(lvl=1, value=.... [rank0]: FakeTensor(..., device='cuda:0', size=(1, s0, 32000), dtype=torch.bfloat16, [rank0]: requires_grad=True) [rank0]: ), GradTrackingTensor(lvl=1, value= [rank0]: FakeTensor(..., device='cuda:0', size=(4096, 32000), dtype=torch.bfloat16) [rank0]: )), **{}): [rank0]: a and b must have same reduction dim, but got [s0, 32000] X [4096, 32000].
Error arises in UnslothEfficientGRPO.apply, I have tried multiple fixes in rl_replacements (in zoo) with no luck - it needs a deeper investigation
Example:
# Reproduction: load Mistral-7B in 4-bit via Unsloth with vLLM fast inference.
from unsloth import FastLanguageModel
import torch

max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

# Load the base model with vLLM generation enabled (fast_inference).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "mistralai/Mistral-7B-Instruct-v0.2",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

# Attach LoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)
import re
from datasets import load_dataset, Dataset

# Counters used by correctness_reward_func for periodic sample logging.
# NOTE: the original had `global COUNTER` / `global PRINT_EVERY` statements
# at module scope; `global` is a no-op outside a function, so they are dropped.
COUNTER = 0
PRINT_EVERY = 20
# Load and prep dataset
# System prompt instructing the model to answer in <reasoning>/<answer> XML blocks.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for rendering a chain-of-thought answer in the same XML layout.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    """Return the text between the last <answer> tag and its closing tag, stripped."""
    tail = text.rsplit("<answer>", 1)[-1]
    return tail.split("</answer>", 1)[0].strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    """Load the GSM8K split and map each row to a chat prompt plus parsed answer."""
    def to_example(row):  # build system+user messages and extract '####' answer
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': row['question']},
            ],
            'answer': extract_hash_answer(row['answer']),
        }
    raw = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    return raw.map(to_example)  # type: ignore
dataset = get_gsm8k_questions()  # full GSM8K train split, formatted for GRPO
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 for each completion whose extracted <answer> equals the reference."""
    global COUNTER
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = list(map(extract_xml_answer, responses))
    q = prompts[0][-1]['content']
    # Print one full sample every PRINT_EVERY calls so progress can be eyeballed.
    if COUNTER % PRINT_EVERY == 0:
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    COUNTER += 1
    # bool * 2.0 yields 2.0 / 0.0, identical to the conditional form.
    return [2.0 * (r == a) for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 when the extracted answer is a bare digit-only integer, else 0.0."""
    rewards = []
    for completion in completions:
        extracted = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if extracted.isdigit() else 0.0)
    return rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions matching the exact newline-delimited XML layout."""
    # No DOTALL flag: the lazy .*? segments cannot span multiple lines.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        rewards.append(0.5 if re.match(pattern, text) else 0.0)
    return rewards
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Loosely reward 0.5 when the completion begins with reasoning/answer tags."""
    # re.match anchors at position 0, so leading text before <reasoning> fails.
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    texts = (completion[0]["content"] for completion in completions)
    return [0.5 if re.match(pattern, t) else 0.0 for t in texts]
def count_xml(text) -> float:
    """Partial-credit reward: 0.125 per correctly placed XML tag, with a small
    per-character penalty for any text trailing the closing </answer> tag."""
    score = 0.0
    if text.count("<reasoning>\n") == 1:
        score += 0.125
    if text.count("\n</reasoning>\n") == 1:
        score += 0.125
    if text.count("\n<answer>\n") == 1:
        score += 0.125
        # Penalize characters dangling after the final "\n</answer>\n".
        score -= 0.001 * len(text.split("\n</answer>\n")[-1])
    if text.count("\n</answer>") == 1:
        score += 0.125
        # The -1 forgives the single expected trailing newline.
        score -= 0.001 * (len(text.split("\n</answer>")[-1]) - 1)
    return score
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's text with count_xml."""
    return [count_xml(completion[0]["content"]) for completion in completions]
max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer
# Optional extra params for vLLM
from unsloth import vLLMSamplingParams

vllm_sampling_params = vLLMSamplingParams(
    min_p = 0.01,
    seed = 3407,
)

# GRPO training configuration; the completion budget is whatever remains of
# the context window after the prompt.
training_args = GRPOConfig(
    learning_rate = 5e-6,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    report_to = "none", # Can use Weights & Biases
    vllm_sampling_params = vllm_sampling_params, # Optional
    temperature = 1.0,
)
# Assemble the trainer with all five reward functions (summed per completion).
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
I have tried many combinations of train args, model configs - nothing
I have
unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str
How can I fix this?
@advpropsys
Did you try with Phi-4 or Qwen and it still errors?
I have
unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str
How i can fix this?
probably an issue with dataset generation, try loading the json in pandas, then convert it to dataset
pd.read_json() then load_dataset() has a attribute to load pandas dataset
Did you try with Phi-4 or Qwen and it still errors?
@shimmyshimmer Phi4-bnb-4bit does indeed work! Apparently it's only mistral problem....
I have unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str How i can fix this?
probably an issue with dataset generation, try loading the json in pandas, then convert it to dataset
pd.read_json() then load_dataset() has a attribute to load pandas dataset
similar error in unsloth 2025.02.12 for Phi-4; 2025.02.04 seems OK
--update-- sorry, 2025.02.12 also is ok for Phi-4, seems just reinstall unsloth fix my env
I have
unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str
How i can fix this?
I had the same issue; downgrading to unsloth-2025.2.9 works well. FYI, if you use Windows you can install it with: pip install "unsloth[cu124-torch251] @ git+https://github.com/unslothai/unsloth.git@179840d3a7b49188c372b56c67c4290d53c29ed6" — replace with your CUDA and torch versions
I have unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str How i can fix this?
I had same issue then I downgrade to unsloth-2025.2.9 work well. FYI if you use window you can install by:
pip install "unsloth[cu124-torch251] @ git+https://github.com/unslothai/unsloth.git@179840d3a7b49188c372b56c67c4290d53c29ed6" — replace with your CUDA and torch versions
does the latest version of unsloth not work for mistral still? and downgrading it seems to work?
I have unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str How i can fix this?
I had same issue then I downgrade to unsloth-2025.2.9 work well. FYI if you use window you can install by:
pip install "unsloth[cu124-torch251] @ git+https://github.com/unslothai/unsloth.git@179840d3a7b49188c372b56c67c4290d53c29ed6"replace your cuda and torch versiondoes the latest version of unsloth not work for mistral still? and downgrading it seems to work?
I think so
I have unsloth_compiled_cache/UnslothGRPOTrainer.py", line 690, in compute_loss [rank0]: prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] [rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]: TypeError: list indices must be integers or slices, not str How i can fix this?
I had same issue then I downgrade to unsloth-2025.2.9 work well. FYI if you use window you can install by:
pip install "unsloth[cu124-torch251] @ git+https://github.com/unslothai/unsloth.git@179840d3a7b49188c372b56c67c4290d53c29ed6"replace your cuda and torch versiondoes the latest version of unsloth not work for mistral still? and downgrading it seems to work?
Older version works. New efficient grpo doesn't
The issue is that MistralForCausalLM_fast_forward always returns logits instead of hidden states. The fix is on the way #1831
The issue is that
MistralForCausalLM_fast_forward always returns logits instead of hidden states. The fix is on the way — #1831
I still have this problem while training Llama-3.2-1B-Instruct at commit 2c0f50160e227936e0011d67e3bc2472c2089629:
File "/code/unsloth_20250226/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 766, in compute_loss prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] ~~~~~~^^^^^^^^^^^^^^ TypeError: list indices must be integers or slices, not str
I try commit ID 179840d3a7b49188c372b56c67c4290d53c29ed6 still the same
and commit ID 512fec6a7b77a930b85a5b5685bf056fbb29ff5e works for me
any suggestion?
The issue is that
MistralForCausalLM_fast_forwardalways returns logits instead of hidden states. The fix is on the way #1831I still have this problem while training Llama-3.2-1B-Instruct at commit 2c0f501:
File "/code/unsloth_20250226/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 766, in compute_loss prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
TypeError: list indices must be integers or slices, not str. I tried commit ID 179840d — still the same.
and commit ID 512fec6 works for me
any suggestion?
That seems to be a completely separate problem. Make an issue for that and attach (a) code to reproduce the problem and (b) full error trace.
The issue is that
MistralForCausalLM_fast_forwardalways returns logits instead of hidden states. The fix is on the way #1831I still have this problem while training Llama-3.2-1B-Instruct at commit 2c0f501:
File "/code/unsloth_20250226/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 766, in compute_loss prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
TypeError: list indices must be integers or slices, not strI try commit ID 179840d still the same and commit ID 512fec6 works for me any suggestion?
That seems to be a completely separate problem. Make an issue for that and attach (a) code to reproduce the problem and (b) full error trace.
in https://github.com/unslothai/unsloth/issues/1836
This should fix the issue: https://github.com/unslothai/unsloth/issues/1836#issuecomment-2685898012
@AiHaibara @xellDart @xudou3 @kings-crown Mistral should work courtesy of @oKatanaaa :)
For Colab / Kaggle, please restart and run all. For local machines, please do:
pip install --force-reinstall --upgrade --no-cache-dir --no-deps unsloth unsloth_zoo