
Question: TorchApproximator.predict - Why no torch.no_grad() and why call forward directly?

Open VanillaWhey opened this issue 3 years ago • 1 comment

Is your feature request related to a problem? Please describe. I was wondering why the predict method of the TorchApproximator class runs the forward pass with gradient tracking enabled and calls self.network.forward(*torch_args, **kwargs) directly.

Describe the solution you'd like Why not wrap the forward pass in a with torch.no_grad(): block, which would save memory and make the subsequent .detach() call unnecessary? Further, calling self.network(*torch_args, **kwargs) instead of .forward() is better practice (if there is no good reason for doing otherwise), since going through __call__ also triggers any registered hooks.
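For reference, a simplified sketch of what I mean (the signature and tensor handling are illustrative, not the actual mushroom-rl code):

```python
import torch

def predict(self, *args, output_tensor=False, **kwargs):
    # Simplified input conversion; the real method also handles numpy arrays, devices, etc.
    torch_args = [torch.as_tensor(a) for a in args]

    if output_tensor:
        # Caller wants a tensor attached to the autograd graph, e.g. for loss computation.
        return self.network(*torch_args, **kwargs)

    # Pure inference: no graph is built, so no .detach() is needed afterwards.
    with torch.no_grad():
        out = self.network(*torch_args, **kwargs)
    return out.numpy()
```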

Describe alternatives you've considered If the desired behaviour is that the output_tensor flag means a tensor with gradients should be returned, the first point is void.

Additional context Additionally, in line 254 of the same file, the .requires_grad_(False) call has no effect, since the preceding .detach() already returns a tensor with requires_grad set to False.
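A minimal standalone check illustrating this (not taken from the file itself):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.detach()
print(y.requires_grad)   # False: detach() already removes the tensor from the graph
y.requires_grad_(False)  # no-op here, since y.requires_grad is already False
print(y.requires_grad)   # still False
```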

VanillaWhey avatar Dec 19 '20 13:12 VanillaWhey

About torch.no_grad(): we cannot use it there, because the same forward pass is also used when computing the losses in actor-critic algorithms, where the output must keep its gradient information. I don't remember if there was a specific reason for using forward instead of the function call; it may be that we had some issues years ago, but I can't recall any, and I don't have time to check this now.
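Schematically (a toy example, not MushroomRL code): the forward pass that produces the prediction is part of the loss in actor-critic updates, so it cannot be wrapped in torch.no_grad():

```python
import torch
import torch.nn as nn

# toy critic network standing in for the approximator's internal network
critic = nn.Linear(4, 1)
states = torch.randn(32, 4)
targets = torch.randn(32, 1)

# the prediction must stay attached to the graph, so it cannot be
# produced under torch.no_grad()
q = critic(states)                        # plays the role of predict(..., output_tensor=True)
loss = nn.functional.mse_loss(q, targets)
loss.backward()                           # gradients flow back into the critic's parameters
```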

output_tensor flag means a tensor with gradients should be returned, the first point is void. -> Yes, it means that predict returns a tensor with gradients.

.requires_grad_(False) has no effect since the .detach() call has already taken care of it. -> That could well be, but I'm not entirely sure, and I have no time to check it now. We did this quite a long time ago, and there was a reasoning behind it at the time.

boris-il-forte avatar Dec 21 '20 09:12 boris-il-forte