ReinforcementLearning.jl

PPO with MaskedPPOTrajectory

Open · navaxel opened this issue 2 years ago • 2 comments

Hello! I've been using PPOPolicy from your project with a MaskedPPOTrajectory. I think there are two issues in the update function:

  1. The legal_action_mask creation at line 266: lam = select_last_dim(flatten_batch(select_last_dim(LAM, 2:n+1)), inds)

Applying flatten_batch() collapses one of the mask's dimensions, so an error occurs at line 303, where the mask is added to the AC network output: logit′ = raw_logit′ .+ ifelse.(lam, 0.0f0, typemin(Float32))

Removing the flatten_batch() call at line 266 may fix the problem (I tried this locally and it worked): lam = select_last_dim(select_last_dim(LAM, 2:n+1), inds)
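To see why a flattened mask cannot be broadcast against the logits, here is a small, package-free sketch. It assumes that flatten_batch merges the last two dimensions of an array and that, for a single environment, the legal-action mask is stored with shape (n_actions, n_steps + 1); both are assumptions, not checked against the package source. flatten_last_two and select_last are hypothetical stand-ins for flatten_batch and select_last_dim, and the sizes and indices are made up.

```julia
# Hypothetical stand-ins for the package helpers (assumed behaviour, not the RL.jl source):
# merge the last two dimensions, and select along the last dimension.
flatten_last_two(x::AbstractArray) = reshape(x, size(x)[1:end-2]..., :)
select_last(x::AbstractArray, inds) = selectdim(x, ndims(x), inds)

n_actions, n_steps = 3, 4
# Assumed single-environment mask layout: (n_actions, n_steps + 1)
LAM = trues(n_actions, n_steps + 1)
LAM[2, :] .= false                       # action 2 is never legal

inds = [1, 3]                            # sampled minibatch columns
raw_logit = zeros(Float32, n_actions, length(inds))

# With flatten_batch: the 2-D mask becomes a vector, so `inds` picks single bits.
flat = flatten_last_two(select_last(LAM, 2:n_steps+1))   # size (12,)
bad = select_last(flat, inds)                            # length 2, not (3, 2)
err = try
    raw_logit .+ ifelse.(bad, 0.0f0, typemin(Float32))
    nothing
catch e
    e                                    # broadcasting (3, 2) with (2,) fails
end

# Without flatten_batch: each sampled column keeps all n_actions mask entries.
lam = select_last(select_last(LAM, 2:n_steps+1), inds)   # size (3, 2)
logit = raw_logit .+ ifelse.(lam, 0.0f0, typemin(Float32))
```

Under these assumptions the flattened mask throws a DimensionMismatch at the logit line, while the unflattened one sets the illegal action's logit to -Inf in every sampled column.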

  2. The entropy_loss computation at line 308: entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)

p′ is the softmax of the AC network output, in which masked (illegal) actions have logits of -Inf. Those actions therefore get p′ values of 0 and log_p′ values of -Inf, and 0 * -Inf evaluates to NaN, which breaks the entire update. I computed the sum myself, skipping the terms where p′ is 0, and that fixed the problem:

sum_p_logp = zero(eltype(p′))
for (i, val) in enumerate(p′)
    # masked actions have p′ == 0 and log_p′ == -Inf: skip them (treat 0 * log 0 as 0)
    if val != 0
        sum_p_logp += val * log_p′[i]
    end
end
entropy_loss = -sum_p_logp * 1 // size(p′, 2)
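A minimal, package-free sketch of the failure and of a vectorized equivalent of the loop above. softmax_cols and logsoftmax_cols are hypothetical column-wise helpers written for this illustration, not the functions the PPO update actually calls, and the logits and mask are made-up values. Dividing by size(p, 2) averages over the batch, like the * 1 // size(p′, 2) factor in the original.

```julia
# Hypothetical column-wise helpers (illustration only, not the RL.jl/Flux code).
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end

function logsoftmax_cols(x::AbstractMatrix)
    s = x .- maximum(x; dims=1)
    return s .- log.(sum(exp.(s); dims=1))
end

raw_logit = Float32[1 1; 0.5 0.5; 2 2]   # (n_actions, batch) = (3, 2), made-up values
lam = Bool[1 1; 0 0; 1 1]                # action 2 is illegal in both columns
logit = raw_logit .+ ifelse.(lam, 0.0f0, typemin(Float32))

p = softmax_cols(logit)                  # masked entries are exactly 0
logp = logsoftmax_cols(logit)            # masked entries are -Inf

# Naive entropy: 0 * -Inf for the masked action gives NaN.
naive = -sum(p .* logp) / size(p, 2)

# Vectorized version of the loop: replace log p by 0 wherever p == 0,
# so the 0 * -Inf product is never formed.
safe_logp = ifelse.(p .> 0, logp, 0.0f0)
entropy_loss = -sum(p .* safe_logp) / size(p, 2)
```

Replacing log p before the multiplication (rather than patching the product afterwards) means no NaN is ever produced, and the whole computation stays a single broadcast instead of a scalar loop.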

If you could give me your opinion on this issue, I'd be grateful! Thanks!

navaxel · Jun 30 '23 21:06

Hello,

This package is currently undergoing a heavy rework that was halted before it was finished. The current maintainers are fixing things little by little, but most algorithms are currently broken, especially on the master branch (which differs quite a bit from the latest release). If you feel knowledgeable enough about PPO to fix it, we'd be happy to help you do that in a PR. Note, however, that this will require you to get acquainted with the refactoring. There is a documentation page to help you contribute an algorithm here (use the version on the main branch, as the release version is obsolete).

I have worked with PPO myself, so I know the algorithm and may be able to help, but I have never worked with the RL.jl implementation. Don't hesitate to tag me if you need anything.

Regarding your issues specifically:

  1. I have never worked with masks, so I'll have to look into this when I can to understand the issue.
  2. This is indeed a problem: the computation behaves as if 0 * -Inf == 0, but in floating point it is NaN. Your loop works as a quick fix, but it is inefficient. A preferable approach is to compute log_p′ as log(p′ + ε), where ε is a small number (like 1f-3) that keeps the argument of the log positive. This is the standard way of doing it; it is much faster and also works on a GPU.
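A sketch of the log(p′ + ε) suggestion with made-up logits and a made-up mask; softmax_cols is a hypothetical helper written for this illustration, and ε = 1f-3 is the value suggested above, not a tuned constant.

```julia
# Hypothetical column-wise softmax (illustration only, not the RL.jl/Flux code).
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end

raw_logit = Float32[1 1; 0.5 0.5; 2 2]   # (n_actions, batch), made-up values
lam = Bool[1 1; 0 0; 1 1]                # action 2 is illegal in both columns
logit = raw_logit .+ ifelse.(lam, 0.0f0, typemin(Float32))
p = softmax_cols(logit)                  # masked entries are exactly 0

ε = 1f-3                                 # suggested value; an assumption, tune as needed
# log(0 + ε) is finite, so the masked terms are 0 * finite = 0 instead of NaN.
entropy_loss = -sum(p .* log.(p .+ ε)) / size(p, 2)
```

Because ε is added inside the log, this slightly underestimates the entropy: with these numbers it comes out about 0.002 below the exact value of roughly 0.5822. A smaller ε shrinks that bias while still keeping the log finite when p is exactly 0.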

HenriDeh · Jul 06 '23 08:07

Regarding the 2nd issue:

According to this document (which is referenced here), the value used for masking is -1f8, a large negative but finite number. I tried nextfloat(typemin(Float32)) instead of typemin(Float32) and found that the update no longer returns NaN, at least in my environment. Perhaps prob on l.159 should be updated in the same way.
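A sketch comparing the three mask values discussed in this thread, with made-up logits and a made-up mask. softmax_cols, logsoftmax_cols and masked_entropy are hypothetical helpers written for this illustration, not functions from the package.

```julia
# Hypothetical column-wise helpers (illustration only, not the RL.jl/Flux code).
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end

function logsoftmax_cols(x::AbstractMatrix)
    s = x .- maximum(x; dims=1)
    return s .- log.(sum(exp.(s); dims=1))
end

# Naive entropy (no special-casing of p == 0) for a given mask value.
function masked_entropy(raw_logit, lam, mask_value::Float32)
    logit = raw_logit .+ ifelse.(lam, 0.0f0, mask_value)
    p = softmax_cols(logit)
    logp = logsoftmax_cols(logit)
    return -sum(p .* logp) / size(p, 2)
end

raw_logit = Float32[1 1; 0.5 0.5; 2 2]   # (n_actions, batch), made-up values
lam = Bool[1 1; 0 0; 1 1]                # action 2 is illegal in both columns

h_inf  = masked_entropy(raw_logit, lam, typemin(Float32))            # -Inf32: NaN
h_big  = masked_entropy(raw_logit, lam, -1f8)                         # value from the referenced document
h_next = masked_entropy(raw_logit, lam, nextfloat(typemin(Float32)))  # equals -floatmax(Float32)
```

With a finite mask value, exp still underflows to exactly 0 for the masked action, so its probability stays 0, but log p is a large finite negative number and 0 times it is 0 instead of NaN. Both finite choices give the same entropy as the loop-based fix in this example.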

qwjyh · Dec 30 '23 01:12

This function was dropped from the package with the RLZoo deprecation.

jeremiahpslewis · Mar 27 '24 07:03