jsonargparse
Extend docs to explain how to make jsonargparse work with pydantic models
This is about using jsonargparse within LightningCLI to generate a nice CLI easily, but the use case is likely more general. pydantic is a fantastic way to use structured types in a type-safe way, with a lot of other benefits, so I am using it in a project to structure the parameters of a complex ML model. In particular, the parameter dataset_params of the LightningDataModule is of type dataset.DatasetParams, which happens to be a subclass of pydantic.BaseModel to get all the niceties. Running this really simple basic trainer script
from pytorch_lightning.utilities.cli import LightningCLI
from models import Model
from dataset import DataModule # This is the LightningDataModule.
cli = LightningCLI(Model, DataModule)
via
# naive attempt: provide a dict to be parsed into the pydantic model
python train.py --data.dataset_params="{'filename': 'features.hdf5'}"
yields the error message
train.py: error: Parser key "data.dataset_params": Type <class 'dataset.DatasetParams'> expects an str or a Dict/Namespace with a class_path entry but got "{'filename': 'features.hdf5'}"
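The error message says the parser treats DatasetParams as a plain class type and therefore expects either a string or a dict containing a class_path entry. A toy stdlib-only sketch of the two shapes (this is not jsonargparse internals; the import path shown is just the one from the report):

```python
# What was actually passed on the command line (after parsing the string):
given = {"filename": "features.hdf5"}

# The shape the error message asks for: a dict with a "class_path" entry
# and, typically, the constructor arguments under "init_args".
expected = {
    "class_path": "dataset.DatasetParams",
    "init_args": {"filename": "features.hdf5"},
}

assert "class_path" not in given  # why the parser rejects the value
assert "class_path" in expected
```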
I found an easy solution around this that is not (well) documented. Although jsonargparse does not directly support pydantic, it does support dataclasses. Now, pydantic has a nice "compatibility mode for dataclasses" that allows one to have a "standard" dataclass instead of a subclass of pydantic.BaseModel. The nice thing is that it is sufficient to apply this trick to dataset.DatasetParams, not to the other nested types that are subclasses of pydantic.BaseModel. So the following makes the above simple script work:
from pydantic.dataclasses import dataclass

# ModelParams is a subclass of pydantic.BaseModel (that has fields that are
# themselves subclasses of pydantic.BaseModel).
from models import ModelParams

# Use pydantic.dataclasses.dataclass here instead of deriving from
# pydantic.BaseModel to make the magic work.
@dataclass
class DatasetParams:
    """Parameters for accessing a dataset."""

    filename: str
    model_params: ModelParams = ModelParams()
where the call is now as simple as expected:
python train.py --data.dataset_params.filename=features.hdf5
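For reference, the reason this mapping of dotted options onto fields can work at all is that a dataclass exposes its fields, types, and defaults, including nested ones. A minimal stdlib-only sketch (plain dataclasses, hypothetical stand-in classes, no jsonargparse or pydantic involved) of the information a parser can extract:

```python
from dataclasses import asdict, dataclass, field, fields, is_dataclass

# Hypothetical stand-ins for ModelParams/DatasetParams from the report.
@dataclass
class ModelParams:
    hidden_size: int = 128

@dataclass
class DatasetParams:
    filename: str = "features.hdf5"
    model_params: ModelParams = field(default_factory=ModelParams)

# A parser can discover the fields (and their types/defaults) to build
# options such as --data.dataset_params.filename automatically.
names = [f.name for f in fields(DatasetParams)]
params = DatasetParams(filename="other.hdf5")

assert is_dataclass(params)
assert names == ["filename", "model_params"]
assert asdict(params)["model_params"] == {"hidden_size": 128}
```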
I request that this be mentioned explicitly in the docs somewhere, since it is really useful.
Thank you for the excellent work that provided exactly what I was looking for.
Well, I was a bit too enthusiastic. To get the pydantic models as nested parameters in the config file, it is necessary to use the @dataclass decorator instead of BaseModel throughout. But this works really nicely.
Thank you for reporting. I have not used pydantic with jsonargparse as you have. I will try it out to understand all the details and document it accordingly.
I think that mentioning pydantic and referring to the relevant parts of its docs will be sufficient, or at least help a lot. If you like, I can draft a PR later this week.
Certainly, you are welcome to contribute.
@bzfhille will you work on this?
I'm also very interested in this topic, as I see pydantic support as the only missing piece of jsonargparse for my config handling. I'm particularly interested in using pydantic's run-time validation of parameters through a dataclass schema, rather than static type checking like mypy.
Hi @vedal. I haven't had time to look at pydantic. From the comments above jsonargparse already works with pydantic when using standard dataclasses, which is the link that you posted. Did you try it? What is missing is documentation saying that this works and what needs to be done to achieve it.
@mauvilsa Thanks for your response. I finally had some time to try this out, and realized that what was missing in the docs for me was an explanation of how to combine CLI with dataclasses in general, which would generalize to pydantic. The missing example would describe how, if possible, to parse the config directly using a dataclass with CLI in this example, instead of using the low-level parsers. Or is the combination CLI + dataclass not supported?
is the combination CLI + dataclass not supported?
It is supported. Just note that CLI without subcommands receives a function. The parameters of the function could have a dataclass as type, creating a nested namespace. Alternatively, CLI could be given a dataclass directly, which would create subcommands. Though that kind of abuses what a dataclass is supposed to be, so I don't like this idea.
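A toy sketch of the first pattern described above (hypothetical names, not jsonargparse's actual machinery): a function parameter typed as a dataclass naturally becomes a nested group of options such as --cfg.name, because dotted keys can be routed into the dataclass constructor.

```python
from dataclasses import dataclass

# Hypothetical config and entry-point function.
@dataclass
class Config:
    name: str = "default"

def fit(cfg: Config, epochs: int = 1) -> str:
    return f"fit {cfg.name} for {epochs} epochs"

# Simulate how a CLI layer might route dotted keys like "cfg.name"
# into the dataclass before calling the function.
raw = {"cfg.name": "test", "epochs": 3}
cfg = Config(**{k.split(".", 1)[1]: v for k, v in raw.items() if k.startswith("cfg.")})
result = fit(cfg, epochs=raw["epochs"])
assert result == "fit test for 3 epochs"
```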
Your proposal sounds good, but it would be better if it were not in the variable interpolation section, since that would mix topics. There is a previous dataclass example in nested-namespaces. How about doing it there? Even better would be to split that section: the nested namespaces section would be very short, and then there would be a section called dataclasses. It could first show an example with CLI and then, for completeness, show that dataclasses can also be used in low-level parsers. The examples could be changed to be more realistic instead of level 1. It could be the same example as in the interpolation section, just without using interpolation.
This didn't work for me. The code below works with
from dataclasses import dataclass
but not with
from pydantic.dataclasses import dataclass
import pytorch_lightning as pl
from pydantic.dataclasses import dataclass
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.cli import LightningCLI


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


@dataclass
class Config:
    name: str


class Module(pl.LightningModule):
    def __init__(self, *, cfg: Config):
        super().__init__()

    def training_step(self, *args):
        pass

    def train_dataloader(self, *args):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def configure_optimizers(self, *args):
        pass


if __name__ == "__main__":
    print(LightningCLI(Module))
Command:
python main.py fit --model.cfg.name test --print_config
Error:
ValueError: Invalid or unsupported input: class=<class '__main__.Config'>, method_or_property=__init__
Versions:
pytorch-lightning 1.7.7 pyhd8ed1ab_0 conda-forge
jsonargparse 4.19.0 pyhd8ed1ab_0 conda-forge
pydantic 1.10.4 py310h90acd4f_1 conda-forge
@karlgem the issue that you reported in https://github.com/omni-us/jsonargparse/issues/100#issuecomment-1408413796 has been fixed in pull request #266. In the same pull request, support for pydantic's BaseModel classes has been added. The idea is that things just work, without much need for explanation in the docs. For now I will close this issue. Feel free to try things out by installing the code currently in the master branch, and please open new issues for any problems encountered.