How to attribute reward to multiple model runs in the same trajectory with PPO

Open dpaleka opened this issue 3 years ago • 7 comments

I want to fine-tune a base model M to maximize a reward R when the model is used inside a more complex system. Take a simple example of the setting. The trajectory is as follows: sample prompt_1 from a dataset of prompts, then

prompt_1 -> M(prompt_1) = out_1
out_1 -> F(out_1) = prompt_2
prompt_2 -> M(prompt_2) = out_2
out_2 -> R(out_2) = reward

where F : str -> str and R : str -> int are some methods defined in my code. Is there a way to do this in the current TRLX framework, preferably online with PPO? Alternative suggestions are welcome.
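
For concreteness, here is a minimal runnable sketch of the trajectory in plain Python, with toy stand-ins for F, R, and M. Everything here is hypothetical glue code to pin down the setup, not anything from the TRLX API:

def F(text: str) -> str:
    # user-defined str -> str transform between the two model calls,
    # e.g. wrapping the first output into a follow-up prompt
    return f"Given the answer '{text}', justify it:"

def R(text: str) -> int:
    # user-defined scalar reward on the final output;
    # a trivial length heuristic as a placeholder
    return len(text.split())

def rollout(M, prompt_1: str) -> int:
    out_1 = M(prompt_1)    # first call to the trainable model
    prompt_2 = F(out_1)    # glue code defined outside the model
    out_2 = M(prompt_2)    # second call to the same model
    return R(out_2)        # one reward for the whole trajectory

# any str -> str callable stands in for M here
reward = rollout(lambda s: s.upper(), "What is 2 + 2?")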

dpaleka avatar Oct 25 '22 18:10 dpaleka

Summarizing from the Discord: let's say M = m(params, x); then currently you should be able to do online RL with

def reward_fn(x):
    # x is out_1; F and the second (frozen) call to M happen inside
    # the reward function, followed by the scalar reward R
    return R(m(params_0, F(x)))

which would be something like

out_1 = m(params_i, prompt_1)   # params_i are the current trainable params
reward = reward_fn(out_1)

This would mean only the first call to M is trained; the second call, the one after F, stays frozen at params_0 and does not change.
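
For reference, here is a sketch of how that reward function would plug into trlx's online training entry point. I'm writing this from memory of the current API, where reward_fn receives a batch of generated strings and returns a list of scalars, so treat the exact signature as an assumption:

import trlx

def m_frozen(text):
    # generate from a frozen copy of the model (params_0), no gradients
    ...

def batched_reward_fn(samples):
    # trlx calls this with the strings produced by the trained model
    # (out_1 above); F and R are the user's functions from the issue
    return [R(m_frozen(F(s))) for s in samples]

prompts = [...]  # dataset of prompt_1 strings
trlx.train("gpt2", reward_fn=batched_reward_fn, prompts=prompts)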

Could you specify the task you are trying to achieve?

cat-state avatar Oct 26 '22 01:10 cat-state

Feedback from a conversation with @LouisCastricato and @Dahoas: This is not supported in the current TRLX version. The closest thing available is attributing the reward only to the first call of M, as in @cat-state's reply.

As of now, TRLX supports only RL setups where all "actions" to attribute the reward to are done before the reward function is called. The recommended solution is to write your own orchestrator to do this. @LouisCastricato says they might merge this if someone implements it correctly.

The main issue is distributed computing (which is on the roadmap), since the plan is to compute rollouts and rewards on different nodes by default soon. Distributed RL is way easier if the reward function is stateless.

My comments: If I understand the code correctly, implementing a separate orchestrator to do this right now might not be too difficult. I'm not sure if you should merge it because it might interfere with future parallelization improvements. If I decide to make a PR, I'll try to keep that in mind.
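
To make the orchestrator idea concrete, here is a rough sketch of what a two-call rollout could look like. Everything below is hypothetical, not existing trlx classes; only the generation helper uses real transformers APIs, and the per-token logprob bookkeeping is the part a proper orchestrator would need to feed into PPO:

import torch

def generate_with_logprobs(model, tokenizer, prompt):
    # sample a continuation and keep per-token logprobs,
    # using standard transformers generate() outputs
    inputs = tokenizer(prompt, return_tensors="pt")
    out = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=64,
        output_scores=True,
        return_dict_in_generate=True,
    )
    gen_tokens = out.sequences[0, inputs["input_ids"].shape[1]:]
    logprobs = torch.stack([
        torch.log_softmax(scores, dim=-1)[0, tok]
        for scores, tok in zip(out.scores, gen_tokens)
    ])
    return tokenizer.decode(gen_tokens), logprobs

def make_two_step_experience(model, tokenizer, prompt_1, F, R):
    out_1, logprobs_1 = generate_with_logprobs(model, tokenizer, prompt_1)
    prompt_2 = F(out_1)
    out_2, logprobs_2 = generate_with_logprobs(model, tokenizer, prompt_2)
    reward = R(out_2)
    # attach the same trajectory reward to both generations; whether to
    # split or discount it across the two calls is a design choice the
    # orchestrator would have to make
    return [
        (prompt_1, out_1, logprobs_1, reward),
        (prompt_2, out_2, logprobs_2, reward),
    ]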

dpaleka avatar Oct 26 '22 14:10 dpaleka

+1, also interested in something like this.

As of now, TRLX supports only RL setups where all "actions" to attribute the reward to are done before the reward function is called.

@dpaleka, isn't this already the case in your very first pseudocode snippet? R is only called after both M calls, not in between, right?

paulbricman avatar Nov 01 '22 12:11 paulbricman

As of now, TRLX supports only RL setups where all "actions" to attribute the reward to are done before the reward function is called.

@dpaleka, isn't this already the case in your very first pseudocode snippet? R is only called after both M calls, not in between, right?

I miswrote; what is true is that TRLX assumes the "actions" are just a single model.generate call.

dpaleka avatar Nov 01 '22 12:11 dpaleka

Hey! I am open to rectifying this, I am just at capacity right now and I don't think we have the engineering manpower for this at the moment. @paulbricman @dpaleka if you two would be interested in implementing it, I'd be happy to assign you and then review it.

LouisCastricato avatar Nov 01 '22 13:11 LouisCastricato

@paulbricman are you on the discord? https://discord.gg/canadagoose

LouisCastricato avatar Nov 01 '22 13:11 LouisCastricato

@paulbricman are you on the discord? https://discord.gg/canadagoose

Yes, just @paul on that server.

Hey! I am open to rectifying this, I am just at capacity right now and I don't think we have the engineering manpower for this at the moment. @paulbricman @dpaleka if you two would be interested in implementing it, I'd be happy to assign you and then review it.

I'd be excited to help implement it, but I'm skeptical about whether I understand PPO well enough, and whether I'm familiar enough with the trlx codebase, to do it. I might be able to contribute if @dpaleka and I both worked on this?

Also thanks for the super quick reply!

paulbricman avatar Nov 02 '22 11:11 paulbricman