[Feature] Add tensorclass annotations for loss functions

Open SandishKumarHN opened this issue 1 year ago • 4 comments

Description

Follow-up from this pull request (copy-pasted):

We plan to use tensorclass (https://github.com/Tensorclass) to represent losses.

The advantage of using a tensorclass rather than a TensorDict for losses is that it lets us use all the features of TensorDict while preserving type annotations and even auto-completion.
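To illustrate the typing argument without a tensordict dependency, the sketch below uses a standard-library `dataclass` as a stand-in for a tensorclass. Tensorclasses are declared the same way, with annotated fields, but hold tensors and carry a batch size. The class and field names here are illustrative (modelled loosely on DDPG-style loss outputs), not torchrl's actual API.

```python
from dataclasses import asdict, dataclass


# Stand-in for a tensorclass: the fields are declared with type annotations,
# so type checkers and IDEs know exactly which attributes exist.
@dataclass
class DDPGLossOutput:
    loss_actor: float  # would be a torch.Tensor in a real tensorclass
    loss_value: float
    pred_value: float


# With a plain dict (or a TensorDict), keys are strings: a typo is only caught
# at runtime and the IDE cannot complete the key names.
as_dict = {"loss_actor": 0.5, "loss_value": 1.25, "pred_value": 3.0}
total_from_dict = as_dict["loss_actor"] + as_dict["loss_value"]

# With a typed class, the same values are attributes that tools can check.
out = DDPGLossOutput(loss_actor=0.5, loss_value=1.25, pred_value=3.0)
total_from_class = out.loss_actor + out.loss_value

# Converting back to a mapping keeps dict-style consumers working.
assert asdict(out) == as_dict
```

With tensordict installed, the same declaration decorated with `tensorclass` and holding `torch.Tensor` fields would keep this attribute access while adding TensorDict-style operations on top.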

Changes:

  • Check the out_keys of the loss.
  • Create a tensorclass with the respective fields.
  • Type the forward method as returning that class (and/or a TensorDict).
  • Add a constructor argument to return the class, False by default.
  • Update the docstrings (not done yet).
  • Write a small test to check that things work as expected. This test should be new and not parametrized: adding one more parameter to the existing tests would make the code much longer and harder to follow, and the tests would take a long time to run.
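The list of changes above can be sketched in plain Python as follows. Everything here is hypothetical: `ToyLoss`, `ToyLossOutput` and the `return_tensorclass` flag name are illustrative, a `dataclass` stands in for a tensorclass, and `forward` is called directly rather than through `torch.nn.Module.__call__`. The output class is declared statically (not generated from the out_keys) so that type checkers and IDEs can see its fields; the test checks that it mirrors the out_keys instead.

```python
from dataclasses import asdict, dataclass, fields
from typing import Dict, Union


# Typed container with one field per out_key of the loss.
@dataclass
class ToyLossOutput:
    loss_actor: float
    loss_value: float


class ToyLoss:
    """Hypothetical loss module standing in for a torchrl loss."""

    # The out_keys this loss writes; the output class must match them.
    out_keys = ["loss_actor", "loss_value"]

    def __init__(self, return_tensorclass: bool = False):
        # Opt-in flag, False by default so existing callers still get a mapping.
        self.return_tensorclass = return_tensorclass

    # The return annotation advertises both possible outputs.
    def forward(
        self, actor_error: float, value_error: float
    ) -> Union[ToyLossOutput, Dict[str, float]]:
        losses = {
            "loss_actor": actor_error ** 2,
            "loss_value": 0.5 * value_error ** 2,
        }
        if self.return_tensorclass:
            return ToyLossOutput(**losses)
        return losses


# A small standalone (non-parametrized) test of the new behaviour.
def test_return_tensorclass():
    # The output class must mirror the loss's out_keys exactly.
    assert [f.name for f in fields(ToyLossOutput)] == ToyLoss.out_keys

    default_out = ToyLoss().forward(2.0, 2.0)
    assert isinstance(default_out, dict)

    typed_out = ToyLoss(return_tensorclass=True).forward(2.0, 2.0)
    assert typed_out.loss_actor == 4.0
    assert typed_out.loss_value == 2.0
    # Both paths expose the same values under the same names.
    assert asdict(typed_out) == default_out


test_return_tensorclass()
```

Keeping the default at False means the new behaviour is strictly opt-in, which is what makes the change non-breaking.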

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds core functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to change)
  • [ ] Documentation (update in the documentation)
  • [ ] Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • [x] I have read the CONTRIBUTION guide (required)
  • [ ] My change requires a change to the documentation.
  • [x] I have updated the tests accordingly (required for a bug fix or a new feature).
  • [ ] I have updated the documentation accordingly.

SandishKumarHN, Feb 13 '24 00:02

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1905

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures

As of commit 9b5f4e64ac273260c612c153018c728c5a7817a1 with merge base 87f3437b26a8841e534b62ef6aa020d5fc287a90:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot], Feb 13 '24 00:02

@vmoens, could you review this once more? The build errors are resource-related and unrelated to this PR.

SandishKumarHN, Feb 24 '24 05:02

@vmoens I addressed most of your comments above, but the doctests are failing with the errors below, which are not caused by this PR's changes.

File "/home/sandish/rl/torchrl/objectives/cql.py", line 128, in cql.CQLLoss
Failed example:
    loss = CQLLoss(actor, qvalue)
Exception raised:
    Traceback (most recent call last):
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/doctest.py", line 1334, in __run
        exec(compile(example.source, filename, "single",
      File "<doctest cql.CQLLoss[16]>", line 1, in <module>
        loss = CQLLoss(actor, qvalue)
      File "/home/sandish/rl/torchrl/objectives/cql.py", line 321, in __init__
        self.convert_to_functional(
      File "/home/sandish/rl/torchrl/objectives/common.py", line 289, in convert_to_functional
        params.apply(
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/nn/params.py", line 125, in new_func
        out = meth(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/base.py", line 3824, in apply
        return self._apply_nest(
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/_td.py", line 659, in _apply_nest
        out = TensorDict(
    TypeError: __init__() got an unexpected keyword argument 'filter_empty'

  File "/pytorch/rl/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1704, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DDPGLoss' object has no attribute 'reduction'

SandishKumarHN, Feb 29 '24 19:02

@vmoens I made changes based on your review. I still see that reduction is not being added in the test_cost.py file, so all of the failures are related to that.
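The earlier `AttributeError: 'DDPGLoss' object has no attribute 'reduction'` is what `torch.nn.Module.__getattr__` raises when code reads `self.reduction` on a module where that attribute was never assigned. The plain-Python sketch below shows the general pattern of validating and storing `reduction` in the constructor so every later lookup succeeds. It is not the torchrl implementation: `ToyDDPGLoss` and `reduce_loss` are hypothetical names, and the torch-style `"mean"`/`"sum"`/`"none"` values are an assumption.

```python
from typing import List, Union

VALID_REDUCTIONS = ("mean", "sum", "none")


def reduce_loss(values: List[float], reduction: str) -> Union[float, List[float]]:
    """Apply a torch-style reduction to a list of per-sample losses."""
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    if reduction == "none":
        return list(values)
    raise ValueError(f"unknown reduction {reduction!r}; expected one of {VALID_REDUCTIONS}")


class ToyDDPGLoss:
    """Hypothetical loss; the point is that `reduction` is always set in __init__."""

    def __init__(self, reduction: str = "mean"):
        if reduction not in VALID_REDUCTIONS:
            raise ValueError(f"unknown reduction {reduction!r}")
        # Assigning the attribute unconditionally means later reads of
        # self.reduction cannot fail with AttributeError.
        self.reduction = reduction

    def forward(self, errors: List[float]) -> Union[float, List[float]]:
        squared = [e ** 2 for e in errors]
        return reduce_loss(squared, self.reduction)
```

For example, `ToyDDPGLoss().forward([1.0, 3.0])` averages the squared errors to 5.0, while `reduction="none"` returns the per-sample list `[1.0, 9.0]`.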

SandishKumarHN, Mar 14 '24 20:03