tianshou
Fix handling of torch "device" association
- [ ] I have marked all applicable categories:
- [x] exception-raising bug
- [ ] RL algorithm bug
- [ ] documentation request (i.e. "X is missing from the documentation.")
- [ ] new feature request
- [x] I have visited the source website
- [x] I have searched through the issue tracker for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
import tianshou, gym, torch, numpy, sys
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
0.4.11 0.21.0 1.12.1.post200 1.23.5 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0] linux
I think a simple, clear and robust mechanism should be designed for moving policies between devices.
The only way I have currently found to move a saved SAC policy from CUDA to CPU is the following:
policy: SACPolicy = torch.load(model_file_name, map_location="cpu")
policy.to(device)
policy.device = device
policy.actor.device = device
policy.actor.preprocess.device = device
policy.actor.preprocess.model.device = device
policy.actor.mu.device = device
policy.actor.sigma.device = device
policy.critic1.device = device
policy.critic1.preprocess.device = device
policy.critic1.preprocess.model.device = device
policy.critic1.last.device = device
policy.critic1_old.device = device
policy.critic1_old.preprocess.device = device
policy.critic1_old.preprocess.model.device = device
policy.critic1_old.last.device = device
policy.critic2.device = device
policy.critic2.preprocess.device = device
policy.critic2.preprocess.model.device = device
policy.critic2.last.device = device
policy.critic2_old.device = device
policy.critic2_old.preprocess.device = device
policy.critic2_old.preprocess.model.device = device
policy.critic2_old.last.device = device
policy.actor_optim.load_state_dict(policy.actor_optim.state_dict())
policy.critic1_optim.load_state_dict(policy.critic1_optim.state_dict())
policy.critic2_optim.load_state_dict(policy.critic2_optim.state_dict())
One idea would be to make every device variable in every (sub)model a reference to a shared variable residing in a single place.
Another idea is to make .device a function or property that looks the device up dynamically instead of storing it.
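The shared-variable idea above could be sketched roughly like this (a hypothetical `DeviceRef` holder, not part of tianshou; the `Preprocess` module here is illustrative only):

```python
import torch
import torch.nn as nn

class DeviceRef:
    """Hypothetical shared device holder: every submodule keeps a
    reference to the same object, so one assignment updates them all."""
    def __init__(self, device="cpu"):
        self.device = torch.device(device)

class Preprocess(nn.Module):
    """Illustrative submodule that moves input tensors at runtime,
    like tianshou's MLP/Net do, but via the shared reference."""
    def __init__(self, device_ref: DeviceRef):
        super().__init__()
        self.device_ref = device_ref  # shared object, not a copied value
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # inputs are moved using the single shared device reference
        return self.linear(torch.as_tensor(x, device=self.device_ref.device))

ref = DeviceRef("cpu")
net = Preprocess(ref)
# one assignment would re-target every module holding the reference:
ref.device = torch.device("cpu")  # e.g. torch.device("cuda") when available
out = net([1.0, 2.0, 3.0, 4.0])
```

With this layout, moving the whole policy needs only `policy.to(device)` plus a single update to the shared holder, instead of the long manual walk shown above.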
As a torch.nn.Module, policy.to(device) will recursively move all parameters and buffers to the same device. See discussions such as 1 and 2.
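The recursion can be seen with a minimal, self-contained module tree (illustrative names; `.to(torch.float64)` is used here only because it is observable on CPU, and device moves propagate the same way):

```python
import torch
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)
        self.register_buffer("scale", torch.ones(2))  # buffers move too

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

model = Outer().to(torch.float64)
# the conversion propagated into the nested submodule's parameters
# and buffers without touching them individually
print(next(model.inner.fc.parameters()).dtype)  # torch.float64
print(model.inner.scale.dtype)                  # torch.float64
```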
Thanks @nuance1979, but in the tianshou implementation it is unfortunately not enough to simply call policy.to, since tianshou uses classes like MLP and Net that have a field called "device", and this field is used at runtime to "move" input tensors, i.e. to make another .to(device) call on them.
Since these attributes appear everywhere, as in the example I show, you either have to find them by recursively exploring a tree of attributes, with the dangers this entails, or you must know a priori the whole structure of the policy and manually change every .device attribute.
The problem, basically, is that these ".device" attributes are not part of torch.nn.Module; they live in other kinds of objects.
I see. Then I'd suggest adopting one of the methods proposed here.
Handling of devices is currently complicated and duplicated throughout tianshou; it should definitely be improved. I think @opcode81 started to look into it. Adding this to the release milestone.
Indeed, this is a bug in Tianshou. The only way to move a policy to a new device is to apply a function like this:
def tianshou_module_with_device(m: torch.nn.Module, device: TDevice) -> torch.nn.Module:
    m = m.to(device)
    for submodule in m.modules():
        if hasattr(submodule, "device"):
            submodule.device = device
    return m
The loop should not be necessary - but it is.
If the device member isn't set correctly in all affected submodules, then newly created objects will be moved to the wrong device. The correct way to handle this is to dynamically retrieve the device of a tensor/module that is known to be already associated with the desired device, rather than to store the device in a member.
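The dynamic-retrieval fix could look roughly like this (a hypothetical sketch, not tianshou's actual API): derive the device from the module's own parameters instead of storing it.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    """Sketch of the suggested fix: no stored device member;
    the device is read off the parameters, which .to() already moves."""
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)

    @property
    def device(self) -> torch.device:
        # always reflects wherever .to() last moved the parameters
        return next(self.parameters()).device

    def forward(self, obs):
        # inputs follow the parameters automatically, so a plain
        # policy.to(device) is sufficient to move everything
        obs = torch.as_tensor(obs, device=self.device)
        return self.model(obs)

net = Net().to("cpu")  # replace "cpu" with "cuda" when available
print(net.device)  # cpu
```

With this pattern, the `tianshou_module_with_device` loop above becomes unnecessary, since no submodule holds a stale device member.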
@arnaujc91
This is also a good first issue and shouldn't be too hard to fix. It's also quite important (though I almost forgot about it). If you want, feel free to have a look
Interesting, will take a look!