Question: TorchApproximator.predict - Why no torch.no_grad() and why call forward directly?
Is your feature request related to a problem? Please describe.
I was wondering why the predict method in the TorchApproximator class computes gradients and calls self.network.forward(*torch_args, **kwargs) directly.
Describe the solution you'd like
Why not use a with torch.no_grad() statement? It would save memory on the one hand and make the detach() call unnecessary on the other. Furthermore, calling self.network(*torch_args, **kwargs) instead of forward is better practice (unless there is a good reason for doing otherwise). A sketch of what this could look like is below.
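For illustration only, here is a minimal sketch of the suggested change. The argument handling, the predict_sketch name, and the tensor/numpy conversions are simplified placeholders, not the actual MushroomRL code:

```python
import torch

def predict_sketch(network, *args, output_tensor=False, **kwargs):
    # Hypothetical, simplified variant of predict illustrating the suggestion.
    torch_args = [torch.as_tensor(a, dtype=torch.float32) for a in args]

    if output_tensor:
        # The caller explicitly asks for a tensor attached to the graph.
        return network(*torch_args, **kwargs)

    # Inference-only path: no graph is recorded, so no detach() is needed,
    # and calling the module (not .forward) keeps module hooks working.
    with torch.no_grad():
        out = network(*torch_args, **kwargs)
    return out.cpu().numpy()
```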
Describe alternatives you've considered
If the desired behaviour is that the output_tensor flag means a tensor with gradients should be returned, the first point is void.
Additional context
Additionally, in line 254 of the same file, the .requires_grad_(False) call has no effect, since the .detach() call has already taken care of it.
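A quick check of this claim in plain PyTorch, independent of MushroomRL:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).detach()

# detach() already returns a tensor that does not require grad,
# so a following .requires_grad_(False) is redundant here.
print(y.requires_grad)   # False, before calling requires_grad_(False)
y.requires_grad_(False)
print(y.requires_grad)   # still False
```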
About torch.no_grad: we cannot use it in forward, as it is also used when computing the losses in actor-critic algorithms. I don't remember if there was a specific reason for using forward instead of the function call; we may have had some issues years ago, but I don't remember any, and I don't have time to check this now.
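A small example of the constraint being described here: if the shared forward path were wrapped in torch.no_grad(), its output could no longer be used to build a loss that backpropagates to the network parameters. This sketch uses a generic nn.Linear, not MushroomRL code:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 1)
x = torch.randn(8, 4)
target = torch.zeros(8, 1)

# Wrapped in no_grad, the output carries no graph, so a loss built on it
# could not be backpropagated to the network parameters.
with torch.no_grad():
    q_detached = net(x)
print(q_detached.requires_grad)  # False

# Without no_grad (current behaviour), the same forward call can feed the
# loss of an actor-critic algorithm and gradients reach the parameters.
q = net(x)
loss = ((q - target) ** 2).mean()
loss.backward()
print(net.weight.grad is not None)  # True
```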
output_tensor flag means a tensor with gradients should be returned, the first point is void. -> Yes, it means a tensor with the gradient is returned.
.requires_grad_(False) has no effect since the .detach() call has already taken care of it. -> That could be, but I'm not quite sure, and I have no time to check it now. We did this quite a long time ago, and there was a reasoning behind it.