openpi
Pi05 weight mismatch between PyTorch and JAX
Hi all,
When I run the script below, the pi05 model produces different outputs before and after conversion (JAX vs. PyTorch), even though the input is identical. Could anyone suggest what might cause this?
I converted the checkpoint with the provided conversion command (below). After conversion, I copied the `assets` folder from the JAX checkpoint into `pi05_libero_pytorch/assets`; without it, loading fails with a "Norm stat not found" error.
```bash
uv run examples/convert_jax_model_to_pytorch.py \
    --checkpoint_dir ~/.cache/openpi/openpi-assets/checkpoints/pi05_libero \
    --config_name pi05_libero \
    --output_path ~/.cache/openpi/openpi-assets/checkpoints/pi05_libero_pytorch
```
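The manual step of copying the JAX checkpoint's `assets` folder into the converted directory can be scripted so it is not forgotten after a re-conversion. This is a minimal sketch, assuming the two checkpoint paths used in the command above; the helper name `copy_assets` is hypothetical and is not part of openpi.

```python
import shutil
from pathlib import Path


def copy_assets(jax_ckpt: str, pytorch_ckpt: str) -> Path:
    """Copy <jax_ckpt>/assets into <pytorch_ckpt>/assets (hypothetical helper).

    Per the "Norm stat not found" error described above, the converted
    checkpoint needs this folder to find its normalization statistics.
    """
    # Expand "~" so cache paths like ~/.cache/... resolve correctly.
    src = Path(jax_ckpt).expanduser() / "assets"
    dst = Path(pytorch_ckpt).expanduser() / "assets"
    if not src.is_dir():
        raise FileNotFoundError(f"No assets folder in {src.parent}")
    # dirs_exist_ok=True lets the copy be re-run over an existing folder.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return dst
```

Usage with the paths from the command: `copy_assets("~/.cache/openpi/openpi-assets/checkpoints/pi05_libero", "~/.cache/openpi/openpi-assets/checkpoints/pi05_libero_pytorch")`.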
The script is:
```python
import os

import numpy as np

from openpi.policies import libero_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config

config = _config.get_config("pi05_libero")

# Original JAX checkpoint (downloaded and cached on first use).
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_libero")
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run the JAX policy on a fixed LIBERO example.
example = libero_policy.make_libero_example()
result = policy.infer(example)
jax_actions = np.array(result["actions"])
del policy  # free the JAX model before loading the PyTorch one

# Converted PyTorch checkpoint; expand "~" explicitly so the path resolves.
pytorch_dir = os.path.expanduser("~/.cache/openpi/openpi-assets/checkpoints/pi05_libero_pytorch")
pytorch_policy = _policy_config.create_trained_policy(config, pytorch_dir)
pytorch_result = pytorch_policy.infer(example)
pytorch_actions = np.array(pytorch_result["actions"])
del pytorch_policy

# Compare the two action outputs element-wise.
diff = jax_actions - pytorch_actions
max_diff = np.max(np.abs(diff))
mean_diff = np.mean(np.abs(diff))
print(f"Max Absolute Difference: {max_diff}")
print(f"Mean Absolute Difference: {mean_diff}")

if max_diff < 1e-5:
    print("The results are effectively identical.")
else:
    print("There are significant differences between the results.")
```
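A single maximum-difference threshold of 1e-5 does not show where the two outputs diverge. The sketch below is a hypothetical helper, `summarize_action_diff`, using only NumPy: it reports the maximum error per action dimension and an `np.allclose` check that combines absolute and relative tolerance. It assumes both arrays share the same shape with action dimensions on the last axis; the default tolerances are illustrative assumptions, not openpi values.

```python
import numpy as np


def summarize_action_diff(
    a: np.ndarray, b: np.ndarray, atol: float = 1e-3, rtol: float = 1e-3
) -> dict:
    """Compare two action arrays of identical shape (hypothetical helper).

    Assumes the last axis indexes action dimensions.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    abs_diff = np.abs(a - b)
    # Max error per action dimension, reduced over all leading axes.
    per_dim_max = abs_diff.reshape(-1, a.shape[-1]).max(axis=0)
    return {
        "max_abs": float(abs_diff.max()),
        "mean_abs": float(abs_diff.mean()),
        "per_dim_max": per_dim_max,
        "worst_dim": int(per_dim_max.argmax()),
        # allclose tests |a - b| <= atol + rtol * |b| element-wise.
        "allclose": bool(np.allclose(a, b, atol=atol, rtol=rtol)),
    }
```

Usage with the arrays from the script: `stats = summarize_action_diff(jax_actions, pytorch_actions)`. Looking at `per_dim_max` next to the overall maximum makes it easier to see whether the discrepancy is spread evenly across dimensions or concentrated in a few.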