
TorchRL - MAML integration

Open vmoens opened this issue 3 years ago • 8 comments

Brand-new version of the old PR

One can test this by running:

python maml_torchrl.py --parallel --seed 1

I have added the --parallel flag to allow users to run more than one environment.

It can run on CUDA if devices are present.

FYI: this should run on the latest main branch of torchrl.
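To give a rough idea of what --parallel and the device selection amount to in torchrl terms, here is an illustrative sketch (not the PR's actual code; the env name, worker count, and rollout length are placeholders):

import torch
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

device = "cuda" if torch.cuda.is_available() else "cpu"
parallel = True  # stands in for the --parallel CLI flag

def make_env():
    # Placeholder task; the MAML script builds its own meta-RL environments.
    return GymEnv("Pendulum-v1", device=device)

if __name__ == "__main__":  # ParallelEnv spawns worker processes
    env = ParallelEnv(4, make_env) if parallel else make_env()
    print(env.rollout(max_steps=10))  # quick smoke test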

vmoens avatar May 10 '22 15:05 vmoens

> @vmoens Hey Vincent, I've adapted the maml_torchrl script using the latest torchrl and corrected my own MAML-RL implementation, but the results are still not aligned. Mind taking a look at why that happened?

Thanks for this work! I see that MetaSGD has lr=0.3 for torchrl but 0.1 for the other version; is that intended?

vmoens avatar Aug 06 '22 09:08 vmoens

> @vmoens Hey Vincent, I've adapted the maml_torchrl script using the latest torchrl and corrected my own MAML-RL implementation, but the results are still not aligned. Mind taking a look at why that happened?

> Thanks for this work! I see that MetaSGD has lr=0.3 for torchrl but 0.1 for the other version; is that intended?

Oops, I forgot to run the torchrl version of MAML with inner lr=0.1. I think we should lower the inner learning rate to 0.1 everywhere. The performance figure now shows the results with lr=0.1.
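For reference, the inner-loop update we are comparing boils down to something like this (a minimal sketch with TorchOpt's MetaSGD; the toy policy and loss are placeholders for the actual MAML-RL ones):

import torch
import torch.nn as nn
import torchopt

# Toy stand-ins for the MAML-RL policy and adaptation loss.
policy = nn.Linear(4, 2)
obs = torch.randn(8, 4)

inner_opt = torchopt.MetaSGD(policy, lr=0.1)  # the inner learning rate under discussion
inner_loss = policy(obs).pow(2).mean()        # placeholder adaptation loss
inner_opt.step(inner_loss)                    # differentiable in-place parameter update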

Benjamin-eecs avatar Aug 06 '22 09:08 Benjamin-eecs

> @vmoens Hey Vincent, I've adapted the maml_torchrl script using the latest torchrl and corrected my own MAML-RL implementation, but the results are still not aligned. Mind taking a look at why that happened?

> Thanks for this work! I see that MetaSGD has lr=0.3 for torchrl but 0.1 for the other version; is that intended?

> Oops, I forgot to run the torchrl version of MAML with inner lr=0.1. I think we should lower the inner learning rate to 0.1 everywhere. The performance figure now shows the results with lr=0.1.

Sorry, just to make sure I understood correctly: the PR indicates lr=0.3, but the test was run with lr=0.1, right?

vmoens avatar Aug 06 '22 10:08 vmoens

> @vmoens Hey Vincent, I've adapted the maml_torchrl script using the latest torchrl and corrected my own MAML-RL implementation, but the results are still not aligned. Mind taking a look at why that happened?

> Thanks for this work! I see that MetaSGD has lr=0.3 for torchrl but 0.1 for the other version; is that intended?

> Oops, I forgot to run the torchrl version of MAML with inner lr=0.1. I think we should lower the inner learning rate to 0.1 everywhere. The performance figure now shows the results with lr=0.1.

> Sorry, just to make sure I understood correctly: the PR indicates lr=0.3, but the test was run with lr=0.1, right?

Oh, I was referring to the performance figure for my own implementation. The torchrl performance figure in this PR was run with lr=0.3. We may have to change the lr to 0.1 and run it again.

Benjamin-eecs avatar Aug 06 '22 10:08 Benjamin-eecs

> @vmoens Hey Vincent, I've adapted the maml_torchrl script using the latest torchrl and corrected my own MAML-RL implementation, but the results are still not aligned. Mind taking a look at why that happened?

> Thanks for this work! I see that MetaSGD has lr=0.3 for torchrl but 0.1 for the other version; is that intended?

I've already updated the results with lr=0.1, and the torchrl implementation is still not learning. Mind taking a closer look at it?

Benjamin-eecs avatar Aug 06 '22 14:08 Benjamin-eecs

> I've already updated the results with lr=0.1, and the torchrl implementation is still not learning. Mind taking a closer look at it?

Of course. Has TorchOpt dropped CPU support? I used to be able to run it on my MacBook, but it seems that nvcc is now needed. It's harder to debug on remote clusters...

cc @XuehaiPan

vmoens avatar Aug 06 '22 15:08 vmoens

> Has TorchOpt dropped CPU support? I used to be able to run it on my MacBook, but it seems that nvcc is now needed.

@vmoens If you are installing TorchOpt from source, you will need nvcc to compile the .cu files. We provide a conda environment with a fully-capable build toolchain.

git clone https://github.com/metaopt/TorchOpt.git
cd TorchOpt

# This branch bumped torch == 1.12.1 and functorch == 0.2.1 (which were released yesterday)
git checkout release/0.4.3

CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml

conda activate torchopt
pip3 install --no-build-isolation --editable .

Note that you do not need a physical GPU installed on the machine to compile the source code.


Alternatively, if you don't want to create a new conda environment or to install TorchOpt from PyPI (there was a bug in the wheels on PyPI; see #46 and #45), you will need to build a new wheel yourself. You can build it using cibuildwheel and docker:

git clone https://github.com/metaopt/TorchOpt.git
cd TorchOpt

# This branch bumped torch == 1.12.1 and functorch == 0.2.1 (which were released yesterday)
git checkout release/0.4.3

python3 -m pip install --upgrade cibuildwheel  # you also need `docker` installed
CIBW_BUILD="cp39*manylinux*" python3 -m cibuildwheel --plat=linux --config-file=pyproject.toml

You can change cp39 in the CIBW_BUILD variable to your specific Python version (e.g. "cp38*manylinux*" for Python 3.8). The new wheel will appear in the wheelhouse/ folder. Then install it with:

python3 -m pip install wheelhouse/torchopt-<tag>.whl
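Either way, a quick post-install sanity check from a Python shell (just an illustration; assuming torchopt exposes __version__ as usual):

import torch
import torchopt

print("torchopt:", torchopt.__version__)
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())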

XuehaiPan avatar Aug 06 '22 16:08 XuehaiPan

Yes, it is learning on train; it just needs better sync on test w.r.t. the seed setting (at least on my machine; let me update the plot).

The results aren't that different, to be honest, and we're still far from a significant learning improvement in the original version (the error bars overlap heavily). I'm happy to do some more hyperparameter optimisation to match the other results better.

vmoens avatar Aug 11 '22 15:08 vmoens