
Error when saving `TensorFlowModelDataset` as partition

Open anabelchuinard opened this issue 2 years ago • 4 comments

Description

Can't save TensorFlowModelDataset objects as partition.

Context

I am dealing with a project where I have to train several models concurrently. I started writing my code using PartitionedDataset, where each partition corresponds to the data for one training run. When I try to save the resulting TensorFlow models as a partition, I get an error. I wonder if this has to do with the fact that these inherit from AbstractVersionedDataset instead of AbstractDataset. And if so, I am interested to know if there is any workaround for batch-saving them.

This is the instance of my catalog corresponding to the models I want to save:

tensorflow_models:
  type: PartitionedDataset
  path: data/derived/ML/models
  filename_suffix: ".hdf5"
  dataset:
    type: kedro.extras.datasets.tensorflow.TensorFlowModelDataset

Note: Saving one model (not as partition) works.

Steps to Reproduce

  1. Generate a bunch of trained models
  2. Try to save them in a partition as TensorFlowModelDataset objects

Expected Result

Should save one .hdf5 file per partition, with the name of the file being the associated dictionary key.

Actual Result

Getting this error:

DatasetError: Failed while saving data to data set PartitionedDataset(dataset_config={}, dataset_type=TensorFlowModelDataset,
path=...).
The first argument to `Layer.call` must always be passed.

Your Environment

  • Kedro version used (pip show kedro or kedro -V): kedro, version 0.18.12
  • Python version used (python -V): 3.9.16
  • Operating system and version: Mac M2

anabelchuinard avatar Aug 21 '23 22:08 anabelchuinard

Hi @anabelchuinard, thanks for opening this issue and sorry for the delay. It will take us some time but I'm labeling this issue so we don't lose track of it.

astrojuanlu avatar Sep 05 '23 11:09 astrojuanlu

Hi @anabelchuinard, do you still need help fixing this issue?

merelcht avatar Jul 08 '24 14:07 merelcht

@merelcht I found a non-kedronic workaround for this but would love to know if there is now a kedronic way for batch-saving those models.

anabelchuinard avatar Jul 08 '24 17:07 anabelchuinard

Using the PartitionedDataset is definitely the recommended Kedro way for batch saving. I've done some digging and it seems that the following lines are causing issues for using the TensorFlowModelDataset with PartitionedDataset:

https://github.com/kedro-org/kedro-plugins/blob/be99fecf6cf5ac8f6a0a717c56b06dbc148b26eb/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py#L313-L314

merelcht avatar Jul 09 '24 13:07 merelcht

Cause of the issue

The issue is in how we implement lazy saving for partitioned datasets. To postpone materialising the data, we expect the dictionary fed to PartitionedDataset to contain Callable values that return the actual objects, rather than the objects themselves.

https://github.com/kedro-org/kedro-plugins/blob/be99fecf6cf5ac8f6a0a717c56b06dbc148b26eb/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py#L313-L314

When saving the data, we check whether a Callable was passed and, if so, call it to get the actual object. Since a TensorFlow model is itself callable, we make this call when saving, which causes the above error, even though the user never meant to apply lazy saving.

So PartitionedDataset currently cannot save callable objects unless they are wrapped in another Callable, for example a lambda.
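A minimal sketch of how the check misfires, with `FakeModel` standing in for a Keras model and `save_partition` a simplified stand-in for the lazy-save logic (invented names, not the actual kedro code):

```python
# Stand-in for a Keras model: callable, and calling it with no input fails,
# mimicking "The first argument to `Layer.call` must always be passed."
class FakeModel:
    def __call__(self, *args):
        if not args:
            raise TypeError("The first argument to `Layer.call` must always be passed.")
        return args[0]

def save_partition(partition_data):
    """Simplified stand-in for the lazy-save check in PartitionedDataset."""
    if callable(partition_data):
        # Lazy saving: call the value to materialise the object to write.
        partition_data = partition_data()
    return partition_data  # the object that would be written to disk

model = FakeModel()

# Passing the model directly: the check treats it as a lazy-save wrapper
# and calls it with no arguments, reproducing the error from the report.
try:
    save_partition(model)
except TypeError as err:
    print(err)

# Wrapping it in a lambda: only the lambda is called; the model is returned intact.
assert save_partition(lambda: model) is model
```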

Current workaround

@anabelchuinard - To make PartitionedDataset save a callable object in the current Kedro version, you need to wrap it as if you wanted to do lazy saving:


# Passing the models directly fails: the save logic treats any callable
# (including a TF model) as a lazy-save wrapper and calls it
save_dict = {
	"tensorflow_model_32": models["tensorflow_model_32"],
	"tensorflow_model_64": models["tensorflow_model_64"],
}

# Wrapping each model in a lambda avoids it being called when saving:
# only the lambda is called, and it returns the model unchanged
save_dict = {
	"tensorflow_model_32": lambda: models["tensorflow_model_32"],
	"tensorflow_model_64": lambda: models["tensorflow_model_64"],
}

Suggested fix

Make PartitionedDataset accept only lambda functions for lazy saving and ignore other callable objects - https://github.com/kedro-org/kedro-plugins/pull/978

Following PR to update docs

https://github.com/kedro-org/kedro/pull/4402

ElenaKhaustova avatar Jan 07 '25 15:01 ElenaKhaustova

Suggested fix Make PartitionedDataset accept only lambda functions for lazy saving and ignore other callable objects - https://github.com/kedro-org/kedro-plugins/pull/978

To me this seems to be a niche case, and changing PartitionedDataset to accept only lambdas is a fairly big breaking change. Any useful callable will likely be more complicated than a simple lambda. Maybe we can allow disabling lazy loading/saving (enabled by default) when specified?

noklam avatar Jan 07 '25 16:01 noklam

Suggested fix Make PartitionedDataset accept only lambda functions for lazy saving and ignore other callable objects - #978

To me this seems to be a niche case, and changing PartitionedDataset to accept only lambdas is a fairly big breaking change. Any useful callable will likely be more complicated than a simple lambda. Maybe we can allow disabling lazy loading/saving (enabled by default) when specified?

I see the point, but I think the issue is a little broader than this case. In particular, I don't think it's right to call any callable object and use that check to decide whether to apply lazy saving. This affects all the ML-model cases (TensorFlow, PyTorch, scikit-learn, etc.) and can potentially execute unwanted code implemented in __call__. Moreover, it's not intuitive for users to have to wrap their objects to avoid such behaviour.

In the solution suggested, I tried to narrow these cases down from any callable to lambdas only, so there is less chance of hitting them.

As an alternative, we can consider making lazy saving a default behaviour so we internally wrap and unwrap objects automatically. But here, the question is whether we need to make it the only option (as it is for lazy loading) or provide some interface to disable it.

ElenaKhaustova avatar Jan 08 '25 11:01 ElenaKhaustova

Thanks for the investigation and PR, @ElenaKhaustova! I agree with @noklam that relying solely on lambda functions for lazy saving doesn't seem like a generic solution. While it is a breaking change, it's hard to determine how much it will impact users. In my opinion, it would be better to avoid treating all Callables as participants in lazy saving by default. However, this would also be a breaking change. As a simpler alternative, we could provide an option to disable lazy saving, as you suggested.

DimedS avatar Jan 08 '25 17:01 DimedS

@noklam, @DimedS, @astrojuanlu

Based on the above arguments, my suggestion would be to make lazy saving a default behaviour like it's done for lazy loading now. For that, we can wrap and unwrap objects internally (instead of asking users to do so manually like we do now), which will guarantee that the Callable we get is expected to be called.

The other question is whether we should provide an option to disable lazy saving. Are there any known cases when disabling it might be critical? Note that we don't have such an option for lazy loading, so it's always enabled.

Please see the edited suggestion below.

ElenaKhaustova avatar Jan 10 '25 14:01 ElenaKhaustova

Based on the above arguments, my suggestion would be to make lazy saving a default behaviour like it's done for lazy loading now. For that, we can wrap and unwrap objects internally (instead of asking users to do so manually like we do now), which will guarantee that the Callable we get is expected to be called.

Hi @ElenaKhaustova, Could you please explain how lazy saving will work? For instance, if I want to enable lazy saving and have a function in one partition that executes some code and returns a pandas DataFrame, how should I modify my function to align with your proposal of wrapping all partitions?

DimedS avatar Jan 14 '25 10:01 DimedS

@DimedS

Could you please explain how lazy saving will work?

I think the easiest way, with minimal changes, is to add a lazy argument to the save() function with a default value of True:

def save(self, data: dict[str, Any], lazy: bool = True) -> None:

Then:

  • If the input object is callable and lazy=True we unwrap it
  • If the input object is not callable and lazy=True we do nothing
  • If the input object is callable and lazy=False we do nothing
  • If the input object is not callable and lazy=False we do nothing
  • Lazy saving will be enabled by default similar to lazy loading

So a user will still need to wrap objects for lazy saving as before, and that behaviour won't change, but there will be a proper option to disable it. Currently, when working with a callable such as a TF model, one needs to wrap it to avoid it being called: https://github.com/kedro-org/kedro-plugins/issues/759#issuecomment-2575562106
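The branching above can be sketched as follows (a simplified stand-in, not the actual PartitionedDataset code; `PartitionedSaverSketch` and `_write` are invented names):

```python
from typing import Any

class PartitionedSaverSketch:
    """Simplified model of a partitioned save() with a `lazy` flag."""

    def __init__(self):
        self.written: dict[str, Any] = {}

    def save(self, data: dict[str, Any], lazy: bool = True) -> None:
        for key, value in data.items():
            if lazy and callable(value):
                # Lazy saving (default): treat the value as a wrapper and
                # call it to materialise the object to write.
                value = value()
            # With lazy=False, callables such as TF models pass through
            # untouched and are written as-is.
            self._write(key, value)

    def _write(self, key: str, value: Any) -> None:
        self.written[key] = value  # stand-in for writing a partition file

saver = PartitionedSaverSketch()
saver.save({"part_1": lambda: 42})  # wrapper is called: 42 is written
print(saver.written["part_1"])      # 42
```

With lazy=False, a TF model would no longer need the lambda wrapper from the workaround earlier in this thread.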

ElenaKhaustova avatar Jan 14 '25 11:01 ElenaKhaustova

Thanks, @ElenaKhaustova! If I understand correctly, the default behavior will remain the same as the current one. However, we are adding the option to use lazy=False, which will prevent Callables from being unwrapped, allowing users to apply it in scenarios like TensorFlow. Is that correct? If so, I really like this idea!

DimedS avatar Jan 14 '25 12:01 DimedS

Thanks, @ElenaKhaustova! If I understand correctly, the default behavior will remain the same as the current one. However, we are adding the option to use lazy=False, which will prevent Callables from being unwrapped, allowing users to apply it in scenarios like TensorFlow. Is that correct? If so, I really like this idea!

Yes, exactly! That's the way to avoid a breaking change.

ElenaKhaustova avatar Jan 14 '25 12:01 ElenaKhaustova

I think the easiest way with minimal changes will be to add lazy argument to save() function with True default value:

def save(self, data: dict[str, Any], lazy: bool = True) -> None:

This sounds like a good and clean solution to me.

merelcht avatar Jan 14 '25 13:01 merelcht

@ElenaKhaustova Would users be able to toggle that from catalog.yml?

astrojuanlu avatar Jan 14 '25 15:01 astrojuanlu

@ElenaKhaustova Would users be able to toggle that from catalog.yml?

Yes, I think we should also add it to make sure disabling is possible not only programmatically but with kedro run as well.
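A catalog entry with such a toggle might look like this (the `save_lazily` key is hypothetical here, shown only to illustrate the shape of the option, and mirrors the entry from the issue description):

```yaml
tensorflow_models:
  type: PartitionedDataset
  path: data/derived/ML/models
  filename_suffix: ".hdf5"
  save_lazily: false  # hypothetical key: do not treat partition values as lazy wrappers
  dataset:
    type: kedro.extras.datasets.tensorflow.TensorFlowModelDataset
```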

ElenaKhaustova avatar Jan 14 '25 16:01 ElenaKhaustova