Serialize Dataset Save Not Working

Open alexpalms opened this issue 9 months ago • 1 comment

Bug description

Hi all, while trying to save a list of trajectories to my local filesystem, I discovered that the save function of the serialize module does not work as expected, at least not as presented in the docs.

When calling serialize.save("my_path", my_trajectories), the code fails with the following trace:

Downloading a pretrained expert.
Sampling expert transitions.
Traceback (most recent call last):
  File "/home/alexpalms/imitation_learning/imitation_quickstart.py", line 42, in <module>
    transitions = sample_expert_transitions()
  File "/home/alexpalms/imitation_learning/imitation_quickstart.py", line 38, in sample_expert_transitions
    serialize.save(rollouts_path, rollouts)
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/imitation/data/serialize.py", line 25, in save
    huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 1515, in save_to_disk
    fs, _ = url_to_fs(dataset_path, **(storage_options or {}))
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/fsspec/core.py", line 383, in url_to_fs
    chain = _un_chain(url, kwargs)
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/fsspec/core.py", line 323, in _un_chain
    if "::" in path
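The trace above is cut off, so the actual exception message is missing. A plausible explanation, assumed here and consistent with the str(p) workaround below, is that the last frame evaluates a substring test `"::" in path` while `path` is a pathlib.Path, and Path objects do not support `in`, so Python raises TypeError. The sketch below reproduces only that membership test with the standard library; `has_chain_separator` and `membership_check_fails` are hypothetical names and are not fsspec's real code.

```python
from pathlib import Path


def has_chain_separator(path) -> bool:
    # Simplified stand-in for the check in the last trace frame
    # (fsspec's _un_chain evaluates `"::" in path`); NOT fsspec's actual code.
    return "::" in path


def membership_check_fails(path) -> bool:
    """Return True if the `"::" in path` test raises TypeError for this object."""
    try:
        has_chain_separator(path)
    except TypeError:
        # pathlib.Path defines no __contains__, __iter__ or __getitem__,
        # so the `in` operator cannot be applied to it.
        return True
    return False


print(membership_check_fails("./rollouts_path"))        # → False (str supports substring tests)
print(membership_check_fails(Path("./rollouts_path")))  # → True (Path does not)
```

If this assumption holds, any Path that reaches that check fails, whereas the same location passed as a plain str goes through.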

I also found a fix (or rather a workaround), but since I have only just come across this library, I am not sure it is the best way to handle the problem, as it might hide compatibility issues with the HF library. To fix it, I cast the path to a string at this line: https://github.com/HumanCompatibleAI/imitation/blob/a8b079c469bb145d1954814f22488adff944aa0d/src/imitation/data/serialize.py#L23 so from this:

huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)

it became this:

huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))
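A slightly more general form of the same workaround would be to normalize the path with os.fspath just before calling save_to_disk. This is only a sketch under assumptions: `save_to_disk_compat` is a hypothetical helper (not part of imitation or datasets), and `RecordingDataset` is a stand-in that records the path it receives, used here instead of a real datasets.Dataset so the example runs on its own.

```python
import os
from pathlib import Path
from typing import Optional, Protocol, Union


class SupportsSaveToDisk(Protocol):
    """Anything exposing a datasets-style save_to_disk(path) method."""

    def save_to_disk(self, dataset_path: str) -> None: ...


def save_to_disk_compat(dataset: SupportsSaveToDisk, path: Union[str, os.PathLike]) -> str:
    """Hypothetical helper: convert `path` to a plain str, then save.

    os.fspath accepts both str and pathlib.Path and returns the string form,
    which is what the str(p) workaround achieves.
    """
    path_str = os.fspath(path)
    dataset.save_to_disk(path_str)
    return path_str


class RecordingDataset:
    """Stand-in for datasets.Dataset that only records the path it was given."""

    def __init__(self) -> None:
        self.saved_to: Optional[str] = None

    def save_to_disk(self, dataset_path: str) -> None:
        self.saved_to = dataset_path


ds = RecordingDataset()
save_to_disk_compat(ds, Path("./rollouts_path"))
print(repr(ds.saved_to))  # → 'rollouts_path'
```

One reason to prefer os.fspath over str() here is that os.fspath raises TypeError for objects that are not paths, whereas str() would silently turn any object into a string.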

Steps to reproduce

To reproduce the problem, you can run the following code, adapted from your example:

import numpy as np
from imitation.data import rollout, serialize
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from pathlib import Path

rng = np.random.default_rng(0)
env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=rng,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
)


def download_expert():
    print("Downloading a pretrained expert.")
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )
    return expert


def sample_expert_transitions():
    expert = download_expert()

    print("Sampling expert transitions.")
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=2),
        rng=rng,
    )
    rollouts_path = Path("./rollouts_path")
    serialize.save(rollouts_path, rollouts)
    return rollout.flatten_trajectories(rollouts)


transitions = sample_expert_transitions()

Environment

  • Operating system and version: Linux Mint 20.3 Una
  • Python version: 3.9.19
  • Output of pip freeze --all:
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
alembic==1.13.1
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.2.2
charset-normalizer==3.3.2
cloudpickle==3.0.0
colorama==0.4.6
colorlog==6.8.2
contourpy==1.2.1
cycler==0.12.1
datasets==2.19.1
dill==0.3.8
docopt==0.6.2
Farama-Notifications==0.0.4
filelock==3.14.0
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2024.3.1
gitdb==4.0.11
GitPython==3.1.43
greenlet==3.0.3
grpcio==1.63.0
gymnasium==0.29.1
huggingface-hub==0.23.0
huggingface-sb3==3.0
idna==3.7
imitation==1.0.0
importlib_metadata==7.1.0
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.0.4
kiwisolver==1.4.5
Mako==1.3.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.4
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
munch==4.0.0
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
optuna==3.6.1
packaging==24.0
pandas==2.2.2
pillow==10.3.0
pip==24.0
protobuf==5.26.1
py-cpuinfo==9.0.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pygame==2.5.2
Pygments==2.18.0
pyparsing==3.1.2
PyQt5==5.15.10
PyQt5-Qt5==5.15.2
PyQt5-sip==12.13.0
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
requests==2.32.1
rich==13.7.1
sacred==0.8.5
scikit-learn==1.4.2
scipy==1.13.0
seals==0.2.1
setuptools==69.5.1
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.30
stable-baselines3==2.1.0
sympy==1.12
tensorboard==2.16.2
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
torch==2.3.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.11.0
tzdata==2024.1
urllib3==2.2.1
wasabi==1.1.2
Werkzeug==3.0.3
wheel==0.43.0
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1

Let me know whether you would like me to open a PR, and whether you have suggestions for improving this handling.

Looking forward to receiving your feedback.

alexpalms, May 21 '24 04:05