pyTorch

Open dylanthomas opened this issue 7 years ago • 6 comments

Is there any plan on the horizon to port this code to PyTorch?

dylanthomas avatar Mar 16 '17 01:03 dylanthomas

We are not planning to implement it for now, but some people have indeed suggested that PyTorch may be faster than TF. It would be great if someone could implement GA3C in PyTorch following our guidelines.

ifrosio avatar Mar 16 '17 16:03 ifrosio

I did a quick trial in one of my branches. Actually, TF is almost twice as fast, probably because the naive way I wrote the vectorized loss involves a lot of function calls. The same issue arises for the Chainer version. The loss takes about as much time to compute as the CNN, if not more. I think it could run faster if the loss were implemented as a dedicated layer.
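One way to check the claim that the loss costs about as much as the CNN is to time the two pieces separately. The following is a minimal sketch using only the standard library; the helper name `mean_seconds` and its parameters are hypothetical and not part of GA3C or PyTorch.

```python
import time


def mean_seconds(fn, repeats=100, warmup=10, sync=None):
    """Return the average wall-clock seconds per call of fn().

    fn:      zero-argument callable to time (hypothetical, e.g. the loss step)
    repeats: number of timed calls
    warmup:  untimed calls made first, so one-off setup costs are excluded
    sync:    optional zero-argument callable run before reading the clock
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure pending work from the warmup has finished
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if sync is not None:
        sync()  # make sure the timed work has finished before stopping the clock
    return (time.perf_counter() - start) / repeats
```

When the work runs on a GPU, kernels are launched asynchronously, so passing `sync=torch.cuda.synchronize` is needed for the numbers to mean anything. Timing the forward pass and the loss computation with the same helper would show which of the two dominates.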

etienne87 avatar Mar 26 '17 13:03 etienne87

Just FYI, my friend was able to reproduce both the speed and the performance of my A3C implementation with his PyTorch code. It batches data differently from GA3C, but the overall structure is similar.

ppwwyyxx avatar Mar 26 '17 17:03 ppwwyyxx

Interesting, @ppwwyyxx! My naive implementation gives something like this:

results txt

I am not sure the problem is in the batching; it may rather be the explicit calls and the many steps of computation for the loss.

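        # Assumes: import torch; import torch.nn.functional as F
        p, v = self.model.forward_multistep(x_, c, h)  # policy logits and value estimates
        probs = F.softmax(p, dim=1)  # actions on dim 1, matching the entropy sum below
        # Clip from below so log() never sees 0; relu(probs - eps) could produce log(0) = -inf.
        probs = torch.clamp(probs, min=Config.LOG_EPSILON)
        log_probs = torch.log(probs)
        adv = rewards - v
        adv = torch.masked_select(adv, mask)  # keep only the steps selected by mask
        log_probs_a = torch.masked_select(log_probs, a)  # we cannot use it because of variable length input
        # Detach the advantage so the policy loss does not backpropagate into the value head.
        piloss = -torch.sum(log_probs_a * adv.detach(), 0)
        # sum(p * log p) is the negative entropy, so adding it to the loss rewards exploration.
        entropy = torch.sum(torch.sum(log_probs * probs, 1), 0) * self.beta
        vloss = torch.sum(adv.pow(2), 0) / 2
        loss = piloss + entropy + vloss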

Does anyone know how to do this more quickly in PyTorch?
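One possible direction, offered as a sketch rather than a measured fix: `F.log_softmax` replaces the separate softmax, clipping and log calls, `gather` picks the taken action by index, and padding is handled by multiplying with a mask instead of `masked_select`. The function name `fused_a3c_loss`, the argument shapes, the float 0/1 mask and the default `beta` are all assumptions made for illustration, and the result is not numerically identical to the snippet above because no epsilon clipping is applied. Whether this closes the gap with TF is untested.

```python
import torch
import torch.nn.functional as F


def fused_a3c_loss(logits, values, actions, rewards, mask, beta=0.01):
    """Policy + value + entropy loss for a padded batch (illustrative sketch).

    logits:  (N, A) float, unnormalized action scores
    values:  (N,)   float, critic predictions
    actions: (N,)   long, index of the action taken at each step
    rewards: (N,)   float, discounted returns
    mask:    (N,)   float, 1.0 for real steps and 0.0 for padding (assumed layout)
    beta:    entropy weight (arbitrary default)
    """
    # Numerically stable log(softmax(logits)) in a single call.
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    # Select log pi(a|s) of the taken action by index rather than by a one-hot mask.
    log_probs_a = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    adv = rewards - values
    # Detach the advantage so the policy term does not train the critic.
    pi_loss = -(log_probs_a * adv.detach() * mask).sum()
    v_loss = 0.5 * (adv.pow(2) * mask).sum()
    # sum(p * log p) is the negative entropy; adding it rewards exploration.
    neg_entropy = beta * ((probs * log_probs).sum(dim=1) * mask).sum()
    return pi_loss + v_loss + neg_entropy
```

Multiplying by the mask keeps every tensor at a fixed shape, which avoids the variable-length outputs of `masked_select`; padded steps contribute zero to both the loss and the gradients.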

etienne87 avatar Mar 26 '17 17:03 etienne87

@ppwwyyxx Is there a public git repo for your friend's PyTorch implementation?

dylanthomas avatar Mar 28 '17 00:03 dylanthomas

Unfortunately, no.

ppwwyyxx avatar Mar 28 '17 03:03 ppwwyyxx