
[INVESTIGATION] PPO value masking

Open · sash-a opened this issue 1 year ago · 0 comments

What do you want to investigate?

It seems like the values of dead agents are not being masked out in the MAPGWithTrustRegionClippingLoss.

Specifically, we get the values of states like this:

distribution_params, values = network.network.apply(params, observations)

These exact values are then used like this:

unclipped_value_error = target_values - values

(Note: it seems that target_values are also not masked out. It is possible that this cancels out the non-masking of values, but ideally both should be masked out.)

This means that, as far as I can tell, the values of dead agents are not masked out. It is possible that all of these observations come from living agents, in which case no masking is needed, but this does not seem to be the case.

Another thing to note: it seems that rlax.truncated_generalized_advantage_estimation does mask out the advantages of dead agents, since we pass in discounts that we calculate from terminals. This should correctly mask dead agents in the policy loss. However, it also affects the target values, which are calculated as target_values = values[:-1] + advantages, so target_values should likely also be masked in the loss.
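
To make the target computation concrete, here is a minimal pure-JAX sketch of truncated GAE written out by hand rather than calling rlax. It assumes the usual recursion (delta_t = r_t + discount_t * v_{t+1} - v_t, A_t = delta_t + discount_t * lambda * A_{t+1}), which we assume matches what rlax.truncated_generalized_advantage_estimation computes. The single-agent trajectory, the terminals array and the 0.99 discount are all hypothetical.

```python
import jax
import jax.numpy as jnp


def truncated_gae(rewards, discounts, values, lambda_):
    """Hand-written truncated GAE (a sketch, not the rlax implementation).

    rewards, discounts: shape [T]; values: shape [T + 1] (last entry is the bootstrap).
    A discount of 0.0 at step t stops bootstrapping from values[t + 1].
    """
    deltas = rewards + discounts * values[1:] - values[:-1]

    def step(next_advantage, inputs):
        delta, discount = inputs
        advantage = delta + discount * lambda_ * next_advantage
        return advantage, advantage

    # reverse=True walks backwards in time but returns outputs in forward order.
    _, advantages = jax.lax.scan(step, jnp.zeros(()), (deltas, discounts), reverse=True)
    return advantages


# Hypothetical trajectory: the agent is terminal from step 1 onwards.
rewards = jnp.array([1.0, 1.0, 0.0])
terminals = jnp.array([0.0, 1.0, 1.0])
discounts = 0.99 * (1.0 - terminals)
values = jnp.array([0.5, 0.5, 0.5, 0.5])

advantages = truncated_gae(rewards, discounts, values, lambda_=0.95)
target_values = values[:-1] + advantages  # same formula as in the issue
```

In this toy run the targets come out as roughly [1.965, 1.0, 0.0]. With a zero discount and zero reward, the target at the last step collapses to 0.0, while values[2] is whatever the network predicts for the dead agent (0.5 here), so the unmasked value error at that step is still non-zero.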

I think the easiest way to implement this is to pass termination into the batch and simply multiply values and target values by termination inside the loss. For sanity's sake, you could also multiply the log probs or the final policy loss by termination (I think terminals is probably a better name for this variable).
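
A minimal sketch of the proposed value masking follows. It assumes a hypothetical alive_mask that is 1.0 for living agents and 0.0 for dead ones; if terminals is 1.0 on death, the mask would be 1.0 - terminals. The helper name masked_value_loss and the 0.5 scaling are illustrative only (not Mava's API), and value clipping is left out.

```python
import jax.numpy as jnp


def masked_value_loss(values, target_values, alive_mask):
    """Hypothetical helper: squared value error with dead agents masked out.

    alive_mask is assumed to be 1.0 for living agents and 0.0 for dead ones.
    """
    # Multiplying both terms by the mask zeroes the error for dead agents.
    unclipped_value_error = target_values * alive_mask - values * alive_mask
    # Plain mean over every entry: dead entries still count in the denominator.
    return 0.5 * jnp.mean(jnp.square(unclipped_value_error))


# Hypothetical batch: 2 timesteps x 2 agents; agent 1 is dead at step 1.
values = jnp.array([[0.5, 0.7], [0.5, 2.0]])
target_values = jnp.array([[1.0, 0.7], [0.5, 0.0]])
alive_mask = jnp.array([[1.0, 1.0], [1.0, 0.0]])

unmasked = masked_value_loss(values, target_values, jnp.ones_like(alive_mask))
masked = masked_value_loss(values, target_values, alive_mask)
```

Here the unmasked loss is 0.53125, almost all of it from the single dead entry, while the masked loss is 0.03125.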

Definition of done

Once an implementation of value masking is done, it should be benchmarked against develop, and the better-performing branch should be kept.

Future Investigations

Something else that should be investigated: if we do masking, should we call jnp.mean(loss), or should we call jnp.sum(loss) / jnp.sum(mask) (i.e. average only over the living agents)?
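
The difference between the two normalizations can be seen in a small sketch. The per-entry losses and masks are hypothetical, and the jnp.maximum guard against an all-dead batch is an addition not mentioned in the issue.

```python
import jax.numpy as jnp


def mean_over_all(loss, mask):
    # Dead entries are zeroed but still counted in the denominator.
    return jnp.mean(loss * mask)


def mean_over_living(loss, mask):
    # Only living entries are counted; the guard avoids dividing by zero.
    return jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), 1.0)


# Hypothetical per-entry losses for 4 agent-steps, all equal to 1.0.
loss = jnp.ones(4)
all_alive = jnp.array([1.0, 1.0, 1.0, 1.0])
half_dead = jnp.array([1.0, 1.0, 0.0, 0.0])
```

With all agents alive both give 1.0; with half the agents dead, mean_over_all drops to 0.5 while mean_over_living stays at 1.0. In other words, jnp.mean scales the loss (and its gradient) by the fraction of living agents, so batches with many deaths contribute smaller updates, whereas the masked mean keeps the per-living-agent scale constant.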

sash-a · Jul 13 '22 10:07