
Extend docs to explain how to make jsonargparse work with pydantic models

bzfhille opened this issue 3 years ago

This is about using jsonargparse within LightningCLI to generate a nice CLI easily, but the use case is likely more general.

pydantic is a fantastic way to use structured types in a type-safe way, with a lot of other benefits, so I am using it in a project to structure the parameters of a complex ML model. In particular, the parameter dataset_params of the LightningDataModule is of type dataset.DatasetParams, which happens to be a subclass of pydantic.BaseModel to get all the niceties. Running this really simple trainer script

from pytorch_lightning.utilities.cli import LightningCLI
from models import Model
from dataset import DataModule  # This is the LightningDataModule.

cli = LightningCLI(Model, DataModule)

via

# naive way to provide the dict to be maybe parsed using pydantic
python train.py --data.dataset_params="{'filename': 'features.hdf5'}"

yields the error message

train.py: error: Parser key "data.dataset_params": Type <class 'dataset.DatasetParams'> expects an str or a Dict/Namespace with a class_path entry but got "{'filename': 'features.hdf5'}"

I found an easy workaround for this that is not (well) documented. Although jsonargparse does not directly support pydantic, it does support dataclasses. Now, pydantic has a nice "compatibility mode for dataclasses" that allows you to have a "standard" dataclass instead of a subclass of pydantic.BaseModel. The nice thing is that it is sufficient to apply this trick to dataset.DatasetParams, not to the other nested types that are subclasses of pydantic.BaseModel. So the following makes the above simple script work:

from pydantic.dataclasses import dataclass
# ModelParams is a subclass of pydantic.BaseModel (that has fields that are themselves
# subclasses of pydantic.BaseModel).
from models import ModelParams

# Use pydantic.dataclasses.dataclass here instead of deriving from pydantic.BaseModel to make the magic work.
@dataclass
class DatasetParams:
    """Parameters for accessing a dataset."""
    filename: str
    model_params: ModelParams = ModelParams()

where the call is now as simple as expected:

python train.py --data.dataset_params.filename=features.hdf5

I request that this be mentioned explicitly in the docs, since it is really useful.

Thank you for the excellent work that provided exactly what I was looking for.

bzfhille avatar Nov 18 '21 19:11 bzfhille

Well, I was a bit too enthusiastic. To get the pydantic models as nested parameters in the config file, it is necessary to use the @dataclass decorator instead of BaseModel throughout. But this works really nicely.

bzfhille avatar Nov 18 '21 19:11 bzfhille

Thank you for reporting. I have not used pydantic with jsonargparse as you have. I will try it out to understand all the details and document it accordingly.

mauvilsa avatar Nov 22 '21 08:11 mauvilsa

I think that mentioning pydantic and referring to the relevant parts of its docs will be sufficient, or at least help a lot. If you like, I can draft a PR later this week.

bzfhille avatar Nov 22 '21 08:11 bzfhille

I think that mentioning pydantic and referring to the relevant parts of its docs will be sufficient, or at least help a lot. If you like, I can draft a PR later this week.

Certainly, you are welcome to contribute.

mauvilsa avatar Nov 23 '21 08:11 mauvilsa

@bzfhille will you work on this?

mauvilsa avatar Jun 20 '22 06:06 mauvilsa

I'm also very interested in this topic, as I see pydantic as the only missing piece of jsonargparse for my config handling. I'm particularly interested in using pydantic's run-time validation of parameters through a dataclass schema, rather than static type checking like mypy.

vedal avatar Nov 01 '22 18:11 vedal

Hi @vedal. I haven't had time to look at pydantic. From the comments above, jsonargparse already works with pydantic when using standard dataclasses, which is the link that you posted. Did you try it? What is missing is documentation saying that this works and what needs to be done to achieve it.

mauvilsa avatar Nov 08 '22 06:11 mauvilsa

@mauvilsa Thanks for your response. I finally had some time to try this out, and realized that what was missing in the docs for me was an explanation of how to combine CLI with a dataclass in general, which would generalize to pydantic.

The missing example would describe how, if possible, to parse a config directly using a dataclass with CLI in this example, instead of using the low-level parsers. Or is the combination CLI + dataclass not supported?

vedal avatar Dec 08 '22 08:12 vedal

is the combination CLI + dataclass not supported?

It is supported. Just note that CLI without subcommands receives a function. The parameters of the function can have a dataclass as their type, creating a nested namespace. Alternatively, CLI could be given a dataclass directly, which would create subcommands. Though that kind of abuses what a dataclass is supposed to be, so I don't like this idea.

Your proposal sounds good, but it would be better if it is not in the variable-interpolation section, since that would mix topics. There is a previous dataclass example in nested-namespaces. How about doing it there? Even better would be to split that section: the nested-namespaces section would then be very short, and there would be a new section called dataclasses. It could first show an example with CLI and then, for completeness, show that dataclasses can also be used in the low-level parsers. The examples could be changed to be more realistic instead of level 1. It could be the same example as in the interpolation section, just without using interpolation.

mauvilsa avatar Dec 10 '22 07:12 mauvilsa

This didn't work for me. The code below works with

from dataclasses import dataclass

but not with

from pydantic.dataclasses import dataclass

import pytorch_lightning as pl

from pydantic.dataclasses import dataclass
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.cli import LightningCLI


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


@dataclass
class Config:
    name: str


class Module(pl.LightningModule):
    def __init__(self, *, cfg: Config):
        super().__init__()

    def training_step(self, *args):
        pass

    def train_dataloader(self, *args):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def configure_optimizers(self, *args):
        pass


if __name__ == "__main__":
    print(LightningCLI(Module))

Command: python main.py fit --model.cfg.name test --print_config

Error: ValueError: Invalid or unsupported input: class=<class '__main__.Config'>, method_or_property=__init__

Versions:

pytorch-lightning         1.7.7              pyhd8ed1ab_0    conda-forge
jsonargparse              4.19.0             pyhd8ed1ab_0    conda-forge
pydantic                  1.10.4          py310h90acd4f_1    conda-forge

karlgem avatar Jan 30 '23 11:01 karlgem

@karlgem the issue that you reported in https://github.com/omni-us/jsonargparse/issues/100#issuecomment-1408413796 has been fixed in pull request #266. In the same pull request, support for pydantic's BaseModel classes has been added. The idea is that things just work, without much need for explanation in the docs. For now I will close this issue. Feel free to try things out by installing the code currently in the master branch. Please open new issues for any problems encountered.

mauvilsa avatar Apr 12 '23 20:04 mauvilsa