Stats missing when using a pretrained pi0 policy
System Info
- `lerobot` version: 0.1.0
- Platform: Linux-5.14.0-284.86.1.el9_2.x86_64-x86_64-with-glibc2.35
- Python version: 3.11.11
- Huggingface_hub version: 0.28.1
- Dataset version: 3.2.0
- Numpy version: 2.1.3
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Cuda version: 12040
- Using GPU in script?: True
Information
- [x] One of the scripts in the examples/ folder of LeRobot
- [x] My own task or dataset (give details below)
Reproduction
First of all, thanks for open-sourcing the amazing pi0 codebase.
To reproduce my error:

- First, instantiate the policy via `policy = PI0Policy.from_pretrained("lerobot/pi0")`.
- Then, modify the `config.json` it downloaded to replace the empty `input_features` with
"input_features": {
"observation.image.top": {
"shape": [
3,
224,
224
],
"type": "VISUAL"
},
"observation.image.left": {
"shape": [
3,
224,
224
],
"type": "VISUAL"
},
"observation.image.right": {
"shape": [
3,
224,
224
],
"type": "VISUAL"
},
"observation.state": {
"shape": [
7
],
"type": "STATE"
}
},
- I then acquired pictures and robot state from a simulator (Simpler + ManiSkill, to be precise, but I don't think this matters much) and packed them into a dictionary to feed to `select_action`.
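For concreteness, here is a hedged sketch of what that observation dictionary looks like on my side. The key names mirror the `input_features` I added to `config.json` above; the batch dimension, the value ranges, and the `"task"` language-instruction key are assumptions about what `select_action` expects, not confirmed API:

```python
import torch

# Sketch of the observation dict fed to select_action (shapes follow the
# input_features above; random data stands in for the simulator frames).
device = "cuda" if torch.cuda.is_available() else "cpu"

observation = {
    "observation.image.top": torch.rand(1, 3, 224, 224, device=device),
    "observation.image.left": torch.rand(1, 3, 224, 224, device=device),
    "observation.image.right": torch.rand(1, 3, 224, 224, device=device),
    "observation.state": torch.rand(1, 7, device=device),
    "task": ["put the spoon on the towel"],  # assumed language-instruction key
}

# action = policy.select_action(observation)  # this is the call that raises
```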
I then got:

```
Traceback (most recent call last):
  File "/scratch/zf540/pi0/aqua-vla/experiments/envs/simpler/test_ckpts_in_simpler.py", line 222, in <module>
    eval_simpler()
  File "/scratch/zf540/pi0/aqua-vla/.venv/lib/python3.11/site-packages/draccus/argparsing.py", line 225, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/scratch/zf540/pi0/aqua-vla/experiments/envs/simpler/test_ckpts_in_simpler.py", line 174, in eval_simpler
    action = policy.select_action(observation)
  File "/scratch/zf540/pi0/aqua-vla/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/zf540/pi0/lerobot/lerobot/common/policies/pi0/modeling_pi0.py", line 276, in select_action
    batch = self.normalize_inputs(batch)
  File "/scratch/zf540/pi0/aqua-vla/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/zf540/pi0/aqua-vla/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/zf540/pi0/aqua-vla/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/zf540/pi0/lerobot/lerobot/common/policies/normalize.py", line 155, in forward
    assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
AssertionError: `mean` is infinity. You should either initialize with `stats` as an argument, or use a pretrained model.
```
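For context on why this fires: `normalize.py` creates the mean/std buffers filled with infinity, and they only become finite once stats are loaded from the checkpoint or passed in explicitly. A minimal stand-in for the failing check (my own illustration, not the library's code):

```python
import torch

# The stats buffers start out as +inf placeholders; if nothing overwrites
# them, the first forward pass through the normalizer trips this assert.
mean = torch.full((7,), float("inf"))  # stand-in for the state mean buffer
assert torch.isinf(mean).any(), "stats were never loaded"
```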
From reading issue #293, I thought that if you load a model with `from_pretrained`, you do not need to specify dataset stats?
Expected behavior
I expect to see some action output from the model. It doesn't have to be a correct rollout, but I want to see the pipeline working so I can tweak it further.
same problem
Same problem when running pi0 on so100:

```
    return self._call_impl(*args, **kwargs)
  File "/home/nikhil/miniconda3/envs/lerobot/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nikhil/miniconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/nikhil/lerobot/lerobot/common/policies/normalize.py", line 169, in forward
    assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
AssertionError: `mean` is infinity. You should either initialize with `stats` as an argument, or use a pretrained model.
```
Same problem. Have you solved it?
Same problem. When I try running eval.py I hit the same error. I see that after the call `policy = cls._load_as_safetensor(instance, model_file, config.device, strict)`, in the diffusion policy the initial 'inf' value of `Policy->normalize_inputs->buffer_observation_state->mean` has changed, but in pi0 this 'inf' value does not change. Does the model.safetensors have some problem?
I am getting this same issue with the so100
Same problem. Have you solved it?
Same problem
When you fine-tune pi0 on your own data, you can load the new checkpoint, and there will not be any problem like this.
Same problem
Sorry guys for not attending to this issue that I opened myself for a long time.
I think I've solved it, but welcome any corrections.
Basically, by design, pi0 is not for zero-shot inference. Their expected use case is that you would fine-tune your model on a dataset that's relevant to your own robot and task.
If you do the fine-tuning, then as @SresserS pointed out, you would have this stats field populated.
However, since many commonly used datasets, such as bridge/fractal, are in pi0's training mix, pi0 can theoretically do some limited zero-shot inference on those setups (for example, in SimplerEnv). To do so, I guess you can try to load the statistics of the corresponding dataset manually and see how it works.
My understanding is that LeRobot's pi0 checkpoint is not trained but rather converted from PI's JAX checkpoint, so I guess this is why the stats field is missing. If this checkpoint were trained from scratch, then you would be able to do zero-shot inference with stats generated during training (although the performance would be sub-par, as expected).
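To make the "load the statistics manually" suggestion concrete, here is a runnable sketch of the principle. I mimic the `buffer_observation_state` layout mentioned earlier in the thread with a stand-alone `nn.ParameterDict`; in the real policy you would write borrowed dataset statistics into `policy.normalize_inputs.buffer_observation_state` the same way. The placeholder mean/std values and the 7-dim state are assumptions; real values must come from the chosen dataset's stored stats, and the exact accessor varies by LeRobot version:

```python
import torch
from torch import nn

# Stand-in for the buffer that normalize.py registers: initialized to +inf,
# which is what makes the assert fire when no stats are ever loaded.
buffer_observation_state = nn.ParameterDict({
    "mean": nn.Parameter(torch.full((7,), float("inf")), requires_grad=False),
    "std": nn.Parameter(torch.full((7,), float("inf")), requires_grad=False),
})

# Hypothetical stats borrowed from a dataset close to the deployment setup
# (e.g. bridge/fractal); zeros/ones here are placeholders, not real values.
borrowed_stats = {"mean": torch.zeros(7), "std": torch.ones(7)}

buffer_observation_state["mean"].data = borrowed_stats["mean"]
buffer_observation_state["std"].data = borrowed_stats["std"]

# The normalizer's sanity check would now pass.
assert not torch.isinf(buffer_observation_state["mean"]).any()
```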
Same problem
Same problem, does that mean we have to load our own fine-tuned checkpoint?
same problem