Pi05 weight mismatch between PyTorch and JAX

Open littlespray opened this issue 5 days ago • 0 comments

Hi all,

When testing with the script below, I found that pi05 produces different outputs before and after weight conversion, even with identical input. What could be the reason?

I used the provided conversion command. After conversion, I copied the assets folder from the JAX checkpoint to pi05_libero_pytorch/assets; otherwise, a Norm stat not found error is reported.

uv run examples/convert_jax_model_to_pytorch.py \
    --checkpoint_dir ~/.cache/openpi/openpi-assets/checkpoints/pi05_libero \
    --config_name pi05_libero \
    --output_path ~/.cache/openpi/openpi-assets/checkpoints/pi05_libero_pytorch
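
The manual step of copying the assets folder can be sketched as follows. This is only an illustration: `copy_assets` is a hypothetical helper, not part of openpi, and it assumes both checkpoint directories use the layout described in this issue (an `assets` folder directly under the checkpoint directory).

```python
import shutil
from pathlib import Path


def copy_assets(jax_ckpt_dir, pytorch_ckpt_dir):
    """Copy the `assets` folder from the JAX checkpoint into the converted one.

    Hypothetical helper (not an openpi API). Assumes `assets` sits directly
    under each checkpoint directory, as described in this issue.
    """
    src = Path(jax_ckpt_dir).expanduser() / "assets"
    dst = Path(pytorch_ckpt_dir).expanduser() / "assets"
    if not src.is_dir():
        raise FileNotFoundError(f"No assets folder found at {src}")
    # dirs_exist_ok=True lets the copy be re-run without failing (Python 3.8+).
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return dst


# Example call with the paths from this issue:
# copy_assets(
#     "~/.cache/openpi/openpi-assets/checkpoints/pi05_libero",
#     "~/.cache/openpi/openpi-assets/checkpoints/pi05_libero_pytorch",
# )
```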

The script is:

import os

import numpy as np
from openpi.policies import libero_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config

config = _config.get_config("pi05_libero")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_libero")

# Run inference with the original JAX checkpoint.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)
example = libero_policy.make_libero_example()
result = policy.infer(example)
jax_actions = np.array(result["actions"])
del policy

# Run inference with the converted PyTorch checkpoint on the same example.
# Expand "~" explicitly rather than relying on the library to do so.
pytorch_checkpoint_dir = os.path.expanduser(
    "~/.cache/openpi/openpi-assets/checkpoints/pi05_libero_pytorch"
)
pytorch_policy = _policy_config.create_trained_policy(config, pytorch_checkpoint_dir)
pytorch_result = pytorch_policy.infer(example)
pytorch_actions = np.array(pytorch_result["actions"])
del pytorch_policy

diff = jax_actions - pytorch_actions
max_diff = np.max(np.abs(diff))
mean_diff = np.mean(np.abs(diff))

print(f"Max Absolute Difference: {max_diff}")
print(f"Mean Absolute Difference: {mean_diff}")

if max_diff < 1e-5:
    print("The results are effectively identical.")
else:
    print("There are significant differences between the results.")
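
A fixed threshold of 1e-5 may be stricter than two different frameworks can satisfy, since floating-point results are generally not bit-identical across backends. The comparison above could also be written with an explicit, adjustable tolerance. This is a minimal sketch: `compare_actions` is a hypothetical helper, not part of openpi, and the default tolerances are assumptions to be tuned.

```python
import numpy as np


def compare_actions(a, b, atol=1e-5, rtol=0.0):
    """Summarize the element-wise difference between two action arrays.

    Hypothetical helper (not an openpi API). The default tolerances are
    illustrative only.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    abs_diff = np.abs(a - b)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        # True when |a - b| <= atol + rtol * |b| holds for every element.
        "allclose": bool(np.allclose(a, b, atol=atol, rtol=rtol)),
    }


# Example call with the arrays from the script above:
# report = compare_actions(jax_actions, pytorch_actions, atol=1e-3)
# print(report)
```

Reporting the maximum and mean differences alongside the pass/fail flag makes it easier to tell a small numerical drift from a genuine weight mismatch.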

littlespray, Dec 01 '25 10:12