trlx

FasterTransformer reward model support

Open • LouisCastricato opened this issue 3 years ago • 8 comments

🚀 The feature, motivation, and pitch

We need the ability to use massive reward models, as this will be necessary for our InstructGPT model. Currently, the size of the reward model is greatly limited, and using GPU accelerators for it comes with a weird set of limitations.

Alternatives

Alternatively, we could use a separate accelerate script for the reward model, or include the reward model within the student class. The latter would be trivial to do but would result in messy code that is not easily extensible.

Additional context

No response

LouisCastricato • Oct 19 '22

Couldn't you run the RM on a separate node and send comparison requests to it over the network?

vblagoje • Nov 07 '22

You could! I think that's actually along the lines of the solutions we're looking at (we're also considering doing rollouts on a separate node for PPO). I think we want a very easy solution for the end user, though: one where they don't really need to think about the size of their reward model as long as they have enough GPU horsepower.

LouisCastricato • Nov 07 '22
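
A minimal sketch of the separate-node idea, using only the Python standard library. The JSON protocol (`{"samples": [...]}` in, `{"rewards": [...]}` out), the function names, and `dummy_reward` (a stand-in for a real reward model) are all hypothetical; this is not trlx's API or Triton's protocol.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.request import Request, urlopen


def dummy_reward(text: str) -> float:
    # Hypothetical stand-in for a large reward model: longer text scores higher.
    return float(len(text))


class RewardHandler(BaseHTTPRequestHandler):
    """Reward node: answers POST {"samples": [...]} with {"rewards": [...]}."""

    def do_POST(self):
        length = int(self.headers["Content-Length"])
        payload = json.loads(self.rfile.read(length))
        rewards = [dummy_reward(s) for s in payload["samples"]]
        body = json.dumps({"rewards": rewards}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # silence per-request logging for the example


def start_reward_server(port: int = 0) -> HTTPServer:
    """Run the reward node in a background thread; port 0 picks a free port."""
    server = HTTPServer(("127.0.0.1", port), RewardHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def remote_reward_fn(samples: list[str], url: str) -> list[float]:
    """Trainer side: send one batch of samples to the reward node, return its scores."""
    req = Request(
        url,
        data=json.dumps({"samples": samples}).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urlopen(req, timeout=30) as resp:
        return json.loads(resp.read())["rewards"]


if __name__ == "__main__":
    server = start_reward_server()
    host, port = server.server_address
    print(remote_reward_fn(["Frog.", "A cat."], f"http://{host}:{port}/"))  # [5.0, 6.0]
    server.shutdown()
    server.server_close()
```

The trainer only ever sees a function from samples to scores, so the reward model's size no longer constrains the training node. A dedicated inference server such as Triton adds dynamic batching and multi-GPU placement on top of this basic pattern.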

Cool. Not sure if it's a good fit, but you could deploy the RM in Triton and invoke it via tritonclient. That's what I would do, though perhaps it won't suit your end users.

vblagoje • Nov 07 '22

Yeah, that's what we're doing internally.

LouisCastricato • Nov 07 '22

Could I ask which reward models you are using? Large ones seem hard to come by.

> We need the ability to use massive reward models

James4Ever0 • Dec 24 '22

I was using FLAN T5 11B zero-shot. @Dahoas has several fine-tuned 6B RMs, though.

LouisCastricato • Dec 24 '22

> FLAN T5 11B

I've reviewed your code (or perhaps you've modified it since). I think the prompt format needs to change to accommodate multiline answers.

```python
# Give a prompt and compare which answer is better,
# i.e. predict the logits to decide between "A" and "B".

# Reference: https://yjernite.github.io/lfqa.html
special_token = "<P>"  # make sure this token exists in the tokenizer, since it differs between models


def replaceTillNothingLeft(string: str, objective: str, target: str = "") -> str:
    # Loop because one removal can create a new occurrence, e.g. "<<P>P>" -> "<P>".
    if not objective or objective in target:
        raise ValueError("objective must be non-empty and absent from target")  # else infinite loop
    while objective in string:
        string = string.replace(objective, target)
    return string


def RTNL(x: str) -> str:  # a def (unlike a lambda) can carry type hints
    return replaceTillNothingLeft(x, special_token)


question = "What is the most beautiful thing in this world?"
ans_0 = "Frog."
ans_1 = "Cat."
mprompt = (
    f"Given the question and two answers, find the better answer.{special_token}"
    f"Question: {RTNL(question)}{special_token}"
    f"A: {RTNL(ans_0)}{special_token}"
    f"B: {RTNL(ans_1)}{special_token}"
    "Mark it as A or B."
)
```

James4Ever0 • Dec 25 '22

I don't think so? It works fine without this.

LouisCastricato • Dec 25 '22
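
In the zero-shot FLAN T5 setup discussed in this thread, where the model is asked to mark A or B, the two next-token logits still have to be turned into a score. A minimal sketch, assuming you have already extracted the logits for the "A" and "B" tokens (how to do that depends on the model and tokenizer and is not shown). `preference_prob` and `pairwise_reward` are hypothetical helpers, and the [-1, 1] scaling is just one possible choice.

```python
import math


def preference_prob(logit_a: float, logit_b: float) -> float:
    """P(A preferred), from a softmax restricted to the "A" and "B" token logits.

    A two-way softmax equals sigmoid(logit_a - logit_b).
    """
    diff = logit_a - logit_b
    # Numerically stable sigmoid: never exponentiate a large positive number.
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    z = math.exp(diff)
    return z / (1.0 + z)


def pairwise_reward(logit_a: float, logit_b: float) -> float:
    """Hypothetical scalar reward in [-1, 1]: positive favours A, negative favours B."""
    return 2.0 * preference_prob(logit_a, logit_b) - 1.0


print(preference_prob(2.0, 2.0))  # 0.5
print(round(preference_prob(3.0, 1.0), 4))  # 0.8808
```

Restricting the softmax to the two answer tokens keeps the score well-defined even when the model spreads probability over unrelated tokens.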

Addressed by the Triton Inference Server client on the add-hh-example branch: https://github.com/CarperAI/trlx/tree/add-hh-example

cat-state • Feb 02 '23