stable-baselines3
Add np.ndarray as a recognized type for TB histograms.
Currently the SB3 tensorboard writer only supports torch.Tensor as a histogram value. However, the SummaryWriter actually also allows np.ndarray as a value. This PR enables this.
Motivation and Context
Closes #1634
- [x] I have raised an issue to propose this change (required for new features and bug fixes)
Types of changes
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
Checklist
- [x] I've read the CONTRIBUTION guide (required)
- [x] I have updated the changelog accordingly (required).
- [ ] My change requires a change to the documentation.
- [x] I have updated the tests accordingly (required for a bug fix or a new feature).
- [ ] I have updated the documentation accordingly.
- [ ] I have opened an associated PR on the SB3-Contrib repository (if necessary)
- [ ] I have opened an associated PR on the RL-Zoo3 repository (if necessary)
- [x] I have reformatted the code using `make format` (required)
- [x] I have checked the codestyle using `make check-codestyle` and `make lint` (required)
- [x] I have ensured `make pytest` and `make type` both pass. (required)
- [x] I have checked that the documentation builds using `make doc` (required)
Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line
This has been fixed in pytorch v2.0.0. I'll look into how to get this working with earlier versions, if at all possible, so that this repo stays compatible with pytorch>=1.13.0 as it currently is. According to the numpy docs, this is a deprecation issue.
From my testing, this works perfectly fine with numpy 1.23.0 and torch 1.13.1
Proof that it's numpy
numpy 1.23.0
# Dockerfile
FROM python:3.11
COPY ./setup.py /src/setup.py
COPY ./stable_baselines3/version.txt /src/stable_baselines3/version.txt
WORKDIR /src
RUN pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html \
numpy==1.23.0 \
tensorboard \
.[tests]
CMD /bin/bash
$ docker build . -t sb3-dev -f Dockerfile
$ docker run -v $PWD:/src/stable-baselines3 --rm sb3-dev python -m pytest /src/stable-baselines3/tests/test_logger.py
============================= test session starts ==============================
platform linux -- Python 3.11.4, pytest-7.4.0, pluggy-1.2.0
rootdir: /src/stable-baselines3
configfile: pyproject.toml
plugins: cov-4.1.0, xdist-3.3.1, env-0.8.2
collected 50 items
stable-baselines3/tests/test_logger.py ................................. [ 66%]
................. [100%]
=============================== warnings summary ===============================
../usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:4
/usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if not hasattr(tensorboard, "__version__") or LooseVersion(
../usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:6
/usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
) < LooseVersion("1.15"):
tests/test_logger.py::test_make_output[tensorboard]
tests/test_logger.py::test_make_output[tensorboard]
tests/test_logger.py::test_make_output[tensorboard]
tests/test_logger.py::test_report_histogram_to_tensorboard[histogram0]
tests/test_logger.py::test_report_histogram_to_tensorboard[histogram1]
/usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/summary.py:386: DeprecationWarning: using `dtype=` in comparisons is only useful for `dtype=object` (and will do nothing for bool). This operation will fail in the future.
cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 50 passed, 7 warnings in 2.43s ========================
numpy 1.24.0
# Dockerfile
FROM python:3.11
COPY ./setup.py /src/setup.py
COPY ./stable_baselines3/version.txt /src/stable_baselines3/version.txt
WORKDIR /src
RUN pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html \
numpy==1.24.0 \
tensorboard \
.[tests]
CMD /bin/bash
$ docker build . -t sb3-dev -f Dockerfile
$ docker run -v $PWD:/src/stable-baselines3 --rm sb3-dev python -m pytest /src/stable-baselines3/tests/test_logger.py
============================= test session starts ==============================
platform linux -- Python 3.11.4, pytest-7.4.0, pluggy-1.2.0
rootdir: /src/stable-baselines3
configfile: pyproject.toml
plugins: cov-4.1.0, xdist-3.3.1, env-0.8.2
collected 50 items
stable-baselines3/tests/test_logger.py ......F...........FF............. [ 66%]
................. [100%]
=================================== FAILURES ===================================
Relevant changes in pytorch
v1.13.1 https://github.com/pytorch/pytorch/blame/49444c3e546bf240bed24a101e747422d1f8a0ee/torch/utils/tensorboard/summary.py#L386
v2.0.0 https://github.com/pytorch/pytorch/blame/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/tensorboard/summary.py#L383
@araffin I could just wrap the code in a try/except until SB3 supports torch >= 2.0.0?
Something like
    try:
        self.writer.add_histogram(key, value, step)
    except TypeError:
        pass
which would still work in the original manner, whilst letting people with newer versions of torch leverage this feature. Then open a tracker issue to ensure it's not forgotten about.
Probably cast to torch tensor automatically (using from_numpy()) and output a warning too?
@araffin Good idea. That should now be fixed, with your suggestions implemented. I tested the new solution in the same manner as outlined above and saw both warnings (the deprecation warning from numpy and the one from the try/except), and the tests passed. Running test_logger.py with coverage enabled showed that all branches are hit, which should protect this solution against regression. Once SB3 supports torch>=2.0.0, the relevant code can be reverted to d37a952.
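For readers following the thread, here is a rough sketch of the "try, then convert and warn" fallback discussed above. It is not the code merged in this PR: `add_histogram_with_fallback`, `StubWriter`, and the `to_tensor` parameter are hypothetical names used only for illustration, and the stub writer stands in for the real SummaryWriter so the example runs without torch installed.

```python
import warnings
from typing import Any, Callable


def add_histogram_with_fallback(
    writer: Any,
    key: str,
    value: Any,
    step: int,
    to_tensor: Callable[[Any], Any],
) -> None:
    """Try to log `value` as-is; on TypeError, warn, convert it and retry once."""
    try:
        writer.add_histogram(key, value, step)
    except TypeError:
        warnings.warn(
            "add_histogram raised a TypeError for the given value; "
            "converting it with to_tensor and retrying.",
            UserWarning,
        )
        writer.add_histogram(key, to_tensor(value), step)


class StubWriter:
    """Hypothetical stand-in for a writer: rejects lists, accepts anything else."""

    def __init__(self) -> None:
        self.calls = []

    def add_histogram(self, key: str, value: Any, step: int) -> None:
        if isinstance(value, list):
            raise TypeError("lists are not supported by this stub")
        self.calls.append((key, value, step))


# Usage: the list is rejected once, converted to a tuple, then logged.
writer = StubWriter()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    add_histogram_with_fallback(writer, "weights", [1, 2, 3], 0, to_tensor=tuple)
```

In the setting of this PR, the role of `to_tensor` would presumably be played by `torch.from_numpy`, as suggested in the thread, and the writer would be the tensorboard SummaryWriter. Unlike a bare `except TypeError: pass`, this shape still records the histogram on older torch versions and tells the user why a conversion happened.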
Log:
$ docker run -v $PWD:/src/stable-baselines3 --rm sb3-dev python -m pytest /src/stable-baselines3/tests/test_logger.py
============================= test session starts ==============================
platform linux -- Python 3.11.5, pytest-7.4.0, pluggy-1.3.0
rootdir: /src/stable-baselines3
configfile: pyproject.toml
plugins: cov-4.1.0, env-1.0.1, xdist-3.3.1
collected 51 items
stable-baselines3/tests/test_logger.py ................................. [ 64%]
.................. [100%]
=============================== warnings summary ===============================
../usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:4
/usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if not hasattr(tensorboard, "__version__") or LooseVersion(
../usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:6
/usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
) < LooseVersion("1.15"):
tests/test_logger.py::test_make_output[tensorboard]
tests/test_logger.py::test_make_output[tensorboard]
tests/test_logger.py::test_make_output[tensorboard]
tests/test_logger.py::test_report_histogram_to_tensorboard[histogram0-False]
tests/test_logger.py::test_report_histogram_to_tensorboard[histogram1-False]
tests/test_logger.py::test_report_histogram_to_tensorboard[histogram2-True]
/usr/local/lib/python3.11/site-packages/torch/utils/tensorboard/summary.py:386: DeprecationWarning: using `dtype=` in comparisons is only useful for `dtype=object` (and will do nothing for bool). This operation will fail in the future.
cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
tests/test_logger.py::test_report_histogram_to_tensorboard[histogram2-True]
/src/stable-baselines3/stable_baselines3/common/logger.py:419: UserWarning: A numpy.ndarray was passed to write which threw a TypeError. This is most likely due to an outdated numpy version (<1.24.0) and/or an outdated torch version (<2.0.0). The ndarray will be converted to a torch.Tensor as a workaround. For more information, see https://github.com/DLR-RM/stable-baselines3/pull/1635
warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 51 passed, 9 warnings in 3.58s ========================
@araffin polite request to revisit this PR please