
Can we use unsloth to train Reward Models?

Open · armsp opened this issue 1 year ago • 11 comments

More of a question than a bug: will you be working on some examples of using unsloth to train reward models (https://huggingface.co/docs/trl/main/en/reward_trainer) as well?

armsp · Mar 19 '24 06:03

@armsp LoRA and QLoRA for reward models, PPO, DPO, etc. are all supported, i.e. anything TRL does, we can do :) But it needs to be LoRA / QLoRA.
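For readers new to reward modeling: as far as I understand TRL's `RewardTrainer`, it scores each (chosen, rejected) pair with a model that outputs a single scalar and minimizes -log σ(r_chosen - r_rejected). The minimal pure-Python sketch below illustrates that pairwise objective. It is not code from this thread or from TRL, and the function names are hypothetical. The loss equals log 2 when both rewards are equal and shrinks toward 0 as the chosen response out-scores the rejected one.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)); avoids overflow in exp for large |x|."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_reward_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of -log sigmoid(r_chosen - r_rejected) over a batch of preference pairs.

    Hypothetical helper: each reward is assumed to be the scalar output of a
    single-output (num_labels=1) model for one response.
    """
    if len(chosen) != len(rejected) or len(chosen) == 0:
        raise ValueError("expected equally many chosen and rejected rewards")
    total = sum(-log_sigmoid(c - r) for c, r in zip(chosen, rejected))
    return total / len(chosen)
```

For example, `pairwise_reward_loss([2.0], [0.0])` is smaller than `pairwise_reward_loss([0.0], [2.0])`, because the first pair already ranks the chosen response higher.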

danielhanchen · Mar 19 '24 08:03

@danielhanchen that's amazing... I was just wondering: are there some docs/examples?

armsp · Mar 19 '24 08:03

@armsp Sadly, I don't. I have a DPO example, but for the rest you'll have to read the TRL docs.

danielhanchen · Mar 19 '24 08:03

If I figure it out myself, maybe I will post it here... meanwhile, feel free to close this issue :)

armsp · Mar 19 '24 08:03

UPDATE:
I got it to work...and there is nothing to it...it just works!!

armsp · Mar 19 '24 09:03

Fantastic!

danielhanchen · Mar 19 '24 09:03

@danielhanchen I have observed some quirky behaviors, though. For example, for the reward model we only need the following target modules:

    target_modules=[
        "q_proj",
        "v_proj",
    ]

but when I remove the other modules, I get an assertion error.
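To make the `target_modules` choice concrete: LoRA leaves each base weight W frozen and adds a trainable low-rank update (alpha / r) * B @ A only on the modules that are listed. Targeting just `q_proj` and `v_proj` therefore trains fewer parameters than targeting every projection. The toy NumPy sketch below illustrates the general LoRA idea under assumed toy sizes (hidden size 64, rank 8, alpha 16). It is not unsloth's or PEFT's implementation, and it does not explain the assertion error.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy frozen base weights for one attention block; sizes are assumptions.
HIDDEN = 64
BASE = {name: rng.standard_normal((HIDDEN, HIDDEN))
        for name in ("q_proj", "k_proj", "v_proj", "o_proj")}


def add_lora(base, target_modules, r=8, alpha=16):
    """Create trainable (A, B, scale) factors only for the targeted modules."""
    adapters = {}
    for name in target_modules:
        if name not in base:
            raise KeyError(f"unknown module: {name}")
        d_out, d_in = base[name].shape
        A = rng.standard_normal((r, d_in)) * 0.01  # small random init
        B = np.zeros((d_out, r))                   # zero init: update starts at 0
        adapters[name] = (A, B, alpha / r)
    return adapters


def forward(name, x, base, adapters):
    """y = W x, plus (alpha / r) * B A x when the module has an adapter."""
    y = base[name] @ x
    if name in adapters:
        A, B, scale = adapters[name]
        y = y + scale * (B @ (A @ x))
    return y


def trainable_params(adapters):
    """Only the LoRA factors count as trainable; base weights stay frozen."""
    return sum(A.size + B.size for A, B, _ in adapters.values())
```

With these toy sizes, `add_lora(BASE, ["q_proj", "v_proj"])` yields 2 × (8·64 + 64·8) = 2048 trainable values, and adapting all four projections doubles that to 4096. Modules left out of the list behave exactly like the frozen base model.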

Also, when we initialize the tokenizer, how do we pass arguments for padding and truncation?
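For context, in plain Hugging Face `transformers`, padding and truncation are usually requested when the tokenizer is called, for example `tokenizer(texts, padding="max_length", truncation=True, max_length=512)`. The `padding_side` and `truncation_side` attributes choose which end is affected. Whether unsloth's returned tokenizer behaves the same way is an assumption here. The pure-Python sketch below emulates what those options do to a list of token ids. `pad_and_truncate` is a hypothetical helper, not a library function.

```python
from typing import List, Tuple


def pad_and_truncate(ids: List[int], max_length: int, pad_id: int,
                     padding_side: str = "right",
                     truncation_side: str = "right") -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly max_length long."""
    if padding_side not in ("left", "right") or truncation_side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    # Truncation: drop tokens from the chosen end until the sequence fits.
    if len(ids) > max_length:
        if truncation_side == "right":
            ids = ids[:max_length]
        else:
            ids = ids[len(ids) - max_length:]
    n_pad = max_length - len(ids)
    mask = [1] * len(ids)  # 1 = real token, 0 = padding
    # Padding: fill with pad_id on the chosen end and mark it 0 in the mask.
    if padding_side == "right":
        return ids + [pad_id] * n_pad, mask + [0] * n_pad
    return [pad_id] * n_pad + ids, [0] * n_pad + mask
```

Right padding keeps real tokens at the start of the sequence. Left truncation keeps the end of an over-long input, which matters when the most informative part of a response comes last.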

armsp · Mar 19 '24 13:03

I think I spoke too soon... it completes the training loop, but when the trainer goes into the evaluation loop, it errors if the default parameters have been changed, for example num_labels=1 (it is 2 by default). This leads me to believe that somehow the parameters are not being propagated to the code that lies beneath unsloth's abstraction. Because of that, for example, I get this error:

  File "my_venv/lib64/python3.10/site-packages/trl/trainer/utils.py", line 552, in compute_accuracy
    accuracy = np.array(predictions == labels, dtype=float).mean().item()
ValueError: operands could not be broadcast together with shapes (9024002,2) (9024002,) 
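The `ValueError` itself follows from NumPy's broadcasting rules. Shapes are matched starting from the trailing dimension, and 2 against 9024002 is neither equal nor 1, so an (N, 2) array cannot be compared elementwise with an (N,) array. The sketch below reproduces that shape check and shows one way a two-column prediction can be reduced to one label per row before comparing. `can_broadcast` and `pairwise_accuracy` are hypothetical helpers. Treating column 0 as the chosen response with label 0 is an assumption. This is not TRL's `compute_accuracy`, and it does not establish why the predictions arrived with an extra dimension.

```python
import numpy as np


def can_broadcast(a_shape, b_shape) -> bool:
    """True if NumPy broadcasting can align the two shapes (trailing dims first)."""
    try:
        np.broadcast_shapes(a_shape, b_shape)
        return True
    except ValueError:
        return False


def pairwise_accuracy(scores: np.ndarray, labels: np.ndarray) -> float:
    """Hypothetical accuracy for (N, 2) pair scores against (N,) integer labels.

    argmax over axis 1 collapses each row to the index of its higher score,
    turning shape (N, 2) into (N,) so the comparison with labels is elementwise.
    """
    if scores.ndim != 2 or scores.shape[1] != 2:
        raise ValueError(f"expected scores of shape (N, 2), got {scores.shape}")
    predictions = np.argmax(scores, axis=1)
    return float(np.mean(predictions == labels))
```

Under the assumed convention of all-zero labels, the accuracy is simply the fraction of rows in which the first score is at least as large as the second.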

armsp · Mar 20 '24 06:03

@armsp Oh no :( I'll check again and get back to you. Sorry about the issue!

danielhanchen · Mar 20 '24 11:03

That's great if this works out of the box. I'd be keen to try ORPO with it.

RonanKMcGovern · Mar 25 '24 09:03

Extreme apologies, I've been extremely busy on my end, so I didn't have time to look at this. Apologies again :(

danielhanchen · Mar 27 '24 17:03