stable-baselines3-contrib

Loading GPU-trained RPPO on CPU

norikazu99 opened this issue (7 comments)

Hello, when I try to load my GPU-trained RPPO (RecurrentPPO) model, I get the following error. (Note: I only want to use the model's predictions; I don't need to resume training.) I'm using the bleeding-edge version of sb3-contrib, installed with `pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/`, and Python 3.10.9.

Code to reproduce: `model = RecurrentPPO.load("/path/model.zip", device=torch.device('cpu'))`

Error generated: `UserWarning: Could not deserialize object _last_lstm_states. Consider using "custom_objects" argument to replace this object. Exception: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.`

When I try the following instead, a different error occurs. Code to reproduce: `custom_objects = {'_last_lstm_states': None}`

`model = RecurrentPPO.load("/path/model.zip", custom_objects=custom_objects, device=torch.device('cpu'))`

Error generated: `Segmentation fault: 11`

norikazu99, Feb 28 '23

Hello, this is indeed a known problem, but it should not prevent you from using the model at test time: it is only a `UserWarning`, not an error. Or do you get an actual error (other than when using `custom_objects`)?
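
For test-time use, a rollout loop can pass the LSTM state to `predict()` explicitly, so it does not depend on the saved `_last_lstm_states` (which, in current sb3-contrib, is presumably only used when collecting rollouts for training). A minimal sketch, assuming an already-loaded `RecurrentPPO` model and a `VecEnv` whose `step()` returns `(obs, rewards, dones, infos)`; `run_recurrent_policy` is a hypothetical helper name, not part of the library:

```python
import numpy as np


def run_recurrent_policy(model, vec_env, n_steps: int) -> float:
    """Roll out a recurrent policy, carrying the LSTM state between calls.

    Returns the sum of rewards over all envs and steps.
    """
    obs = vec_env.reset()
    lstm_states = None  # predict() creates zero hidden states when given None
    # True marks the first step of an episode, so the hidden state is reset there
    episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    total_reward = 0.0
    for _ in range(n_steps):
        action, lstm_states = model.predict(
            obs, state=lstm_states, episode_start=episode_starts, deterministic=True
        )
        obs, rewards, dones, _infos = vec_env.step(action)
        total_reward += float(np.sum(rewards))
        episode_starts = dones  # envs that just finished start a new episode next step
    return total_reward
```

Usage would look like `run_recurrent_policy(model, vec_env, n_steps=1000)` after `RecurrentPPO.load(..., device="cpu")`.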

Another solution is to load the model again on the GPU machine, move it to the CPU (including `_last_lstm_states`), and then save it again.
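
A minimal sketch of that workaround, to be run on the machine where the model was trained (CUDA available). It assumes that `_last_lstm_states` is a (possibly nested) tuple or namedtuple of PyTorch tensors, as it presumably is in current sb3-contrib, and that overwriting this private attribute before `save()` is acceptable. `move_to_device` and `convert_checkpoint_to_cpu` are hypothetical helper names, not library functions:

```python
from typing import Any


def move_to_device(obj: Any, device: str) -> Any:
    """Recursively move tensor-like objects inside tuples/lists/dicts to `device`.

    Duck-typed: anything with a callable `.to()` method (e.g. a torch.Tensor)
    is moved; namedtuples are rebuilt with their original type.
    """
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple, e.g. RNNStates
        return type(obj)(*(move_to_device(item, device) for item in obj))
    if isinstance(obj, (tuple, list)):
        return type(obj)(move_to_device(item, device) for item in obj)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if callable(getattr(obj, "to", None)):
        return obj.to(device)
    return obj  # None, numbers and other objects are returned unchanged


def convert_checkpoint_to_cpu(src_path: str, dst_path: str) -> None:
    """Re-save a GPU-trained RecurrentPPO model for loading on a CPU-only machine.

    Must run where CUDA is available, since unpickling the saved
    `_last_lstm_states` recreates CUDA tensors.
    """
    # Imported here so `move_to_device` can be used without sb3-contrib installed.
    from sb3_contrib import RecurrentPPO

    # device="cpu" maps the policy weights to the CPU while loading.
    model = RecurrentPPO.load(src_path, device="cpu")
    # The saved LSTM states were pickled as CUDA tensors; move them as well.
    model._last_lstm_states = move_to_device(model._last_lstm_states, "cpu")
    model.save(dst_path)
```

After `convert_checkpoint_to_cpu("/path/model.zip", "/path/model_cpu.zip")`, loading `model_cpu.zip` with `device="cpu"` on the Mac should no longer need to unpickle CUDA tensors.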

araffin, Feb 28 '23

Hello @araffin, I really appreciate the quick response. The Python script (or the notebook kernel) crashes a couple of seconds after the `UserWarning`, so I can't use the model for testing. The PyTorch and sb3-contrib versions are the same on both machines; however, machine 1 has stable-baselines3 1.8.0a4 installed and machine 2 has 1.8.0a7. Could that be the problem?

It's late at night where I am, so I'll look into the other solution first thing tomorrow morning. Once again, thanks for your help.

norikazu99, Mar 01 '23

> Python file or kernel crashes a couple of seconds after UserWarning

What is your PyTorch version? Could you try upgrading?

> has stablebaselines3 version 1.8.0a4 and second machine has version 1.8.0a7, installed could that be the problem?

This should not be a problem, but to be safe, try to always use the same versions on both machines.
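
To compare versions between the two machines without unpickling anything from the model, one option is to read the `system_info.txt` entry that recent SB3 versions write into the saved `.zip`. A minimal stdlib sketch, assuming that entry consists of `- Key: value` lines (for example `Python`, `PyTorch`, `Numpy`); `read_saved_system_info` is a hypothetical helper name. `RecurrentPPO.load(..., print_system_info=True)` prints similar information, but goes through the regular loading path.

```python
import zipfile


def read_saved_system_info(model_zip: str) -> dict[str, str]:
    """Parse the "system_info.txt" entry of an SB3 model archive.

    Assumes lines of the form "- Key: value"; returns {} if the entry is absent.
    """
    with zipfile.ZipFile(model_zip) as archive:
        if "system_info.txt" not in archive.namelist():
            return {}
        text = archive.read("system_info.txt").decode("utf-8")
    info: dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if line.startswith("- ") and ":" in line:
            # Split on the first colon only, so values may contain colons.
            key, value = line[2:].split(":", 1)
            info[key.strip()] = value.strip()
    return info
```

The returned entries can then be compared against the versions installed on the loading machine.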

araffin, Mar 02 '23

The first machine (Windows) has PyTorch 1.13.1 with CUDA, and the second machine (macOS) has PyTorch 1.13.1 without CUDA. I believe that's the latest version; I just installed it using conda.

norikazu99, Mar 02 '23

Oh, I see, this looks like a PyTorch problem across operating systems. Can you confirm that everything works when you save and load on the same machine?

araffin, Mar 03 '23

Hello, I currently don't have access to the Windows machine, but I can confirm that with stable-baselines3 I'm able to load, on my Mac with device `'cpu'`, a PPO model that was trained on the Windows machine using the GPU. So the problem seems to be related to loading `_last_lstm_states`. I will confirm whether I can save and load on the same machine as soon as I have access to the Windows machine again.
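
One way to see which saved attributes are stored as pickled Python objects (and are therefore candidates for the `custom_objects` argument) is to read the archive's `data` entry directly, without unpickling anything. A minimal stdlib sketch, assuming the SB3 archive layout in which `data` is a JSON file and non-JSON-serializable attributes are dicts with a `":serialized:"` key; `list_pickled_attributes` is a hypothetical helper name:

```python
import json
import zipfile


def list_pickled_attributes(model_zip: str) -> list[str]:
    """Return the names of attributes stored as pickled objects in an SB3 model archive."""
    with zipfile.ZipFile(model_zip) as archive:
        data = json.loads(archive.read("data").decode("utf-8"))
    # Pickled entries are assumed to be dicts carrying a ":serialized:" payload.
    return sorted(
        key
        for key, value in data.items()
        if isinstance(value, dict) and ":serialized:" in value
    )
```

For a RecurrentPPO archive, `_last_lstm_states` would be expected to appear in `list_pickled_attributes("/path/model.zip")`.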

norikazu99, Mar 10 '23

I solved this by downgrading numpy.

https://github.com/ray-project/ray/issues/31293#issuecomment-1364667180

frasermcghan, May 24 '23