rlpyt
Why do DQN-related agents/algos default to CPU for loss computation?
Hello @astooke, great work with this wonderful library. Working with DQN/Cat_DQN, it seems that the __call__ method of a DqnAgent/CatDqnAgent/R2d1Agent transfers the result of its computation to the CPU, and the same happens when calling their target method. As a side effect, the loss is computed on the CPU, which slows down training, especially when the action space is large (1000 discrete actions in my case). In this setting, the cost of transferring the whole q(s, a) tensor to the CPU is also non-negligible. Is there a particular reason it has to be done this way? Would it be possible to delay the transfer to the CPU and restrict it to the data that actually needs to be there? From my point of view, it is not necessary to transfer the q(s, a) data to the CPU when computing the loss, only the td_abs_errors (see the sketch below). Does that make sense, or am I missing something? I would be glad to make a PR for this if so. Thanks.
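To make the proposal concrete, here is a minimal sketch of the intended data flow in plain PyTorch. This is not rlpyt's actual loss() code, and the function name and arguments are hypothetical; it only illustrates computing the TD loss on the device holding the Q-values and moving just the small per-sample |TD error| tensor to the CPU (e.g. for prioritized-replay bookkeeping):

```python
import torch


def dqn_loss_on_device(q, target_q, action, reward, done, gamma=0.99):
    # Hypothetical sketch, not rlpyt's loss() implementation.
    # q, target_q: (batch, n_actions) tensors living on the GPU.
    # action, reward, done: (batch,) tensors on the same device.
    q_sa = q.gather(1, action.long().unsqueeze(1)).squeeze(1)  # q(s, a)
    with torch.no_grad():
        target_max = target_q.max(dim=1).values
        y = reward + gamma * (1.0 - done.float()) * target_max
    delta = y - q_sa
    loss = 0.5 * delta.pow(2).mean()  # loss (and backward pass) stay on GPU
    td_abs_errors = delta.abs().detach().cpu()  # only this small tensor crosses
    return loss, td_abs_errors
```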
Hi! Thanks for the kind words. :)
And good question. No, I don't think it's necessary for the loss to be computed on the CPU, great catch! I had kind of left this hanging before and never got back to it. The one thing that does have to go onto the CPU is the Q-values during sampling, so they can be logged in agent_info in case you want to look at them or do something else with them; if you're not using them, you could just turn that off.
Would be great to have it stay on GPU as much as possible, whenever that's faster!
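For reference, one way this split could look (a sketch only, not rlpyt's actual agent classes; the class and method names here are hypothetical, assuming a plain PyTorch nn.Module): training-time calls return Q-values on the model's device, and the CPU copy happens only in the sampling path where agent_info logging wants it.

```python
import torch


class DeviceDqnAgentSketch:
    """Hypothetical DQN-style agent: __call__ (training) keeps Q-values
    on the model's device; step() (sampling) makes the only CPU copy,
    for agent_info logging, and could skip it if logging is disabled."""

    def __init__(self, model, device):
        self.model = model.to(device)
        self.device = device

    def __call__(self, observation):
        # Training path: result stays on-device for the loss computation.
        return self.model(observation.to(self.device))

    @torch.no_grad()
    def step(self, observation):
        # Sampling path: greedy action; the CPU transfer happens only here.
        q = self.model(observation.to(self.device))
        action = q.argmax(dim=-1)
        return action.cpu(), {"q": q.cpu()}
```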