
Allow computer vision task models to run without tf.data

Open mattdangerw opened this issue 9 months ago • 5 comments

First step (of a few) toward slowly relaxing our reliance on tf.data for preprocessing.

Our text models are more heavily reliant on tf.data because of the tf-text dependency. Our image models do not have this constraint.

We could try to allow preprocessing to run without tf.data on the torch and jax backends. To do so, we would need to stop always converting inputs to a tf.data.Dataset in our pipeline model helper here, and find a way to still apply preprocessing to the iterator efficiently.
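As a rough illustration of the shape this could take (a sketch only; `preprocessed_batches`, `toy_batches`, and the resize layer below are stand-ins, not existing KerasNLP APIs):

```python
import numpy as np
import keras

# Hypothetical sketch: apply a Keras preprocessing layer lazily to any
# Python iterable of batches, instead of materializing a tf.data.Dataset.
def preprocessed_batches(batch_iterator, preprocessor):
    for x, y in batch_iterator:
        # The preprocessor is a plain Keras layer, so this runs on any backend.
        yield preprocessor(x), y

# Example usage with a toy iterator of NumPy batches.
def toy_batches(num_batches=4, batch_size=8):
    for _ in range(num_batches):
        yield np.random.uniform(size=(batch_size, 224, 224, 3)), np.zeros((batch_size,))

resize = keras.layers.Resizing(96, 96)  # stands in for an image preprocessor
for images, labels in preprocessed_batches(toy_batches(), resize):
    print(images.shape, labels.shape)
```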

mattdangerw avatar Mar 06 '25 01:03 mattdangerw

Contributions are welcome here, but this is a fairly abstract problem that would need some scouting out first. We could try to leverage Keras' DataAdapter here; I'm not sure how best to iterate over the dataset and apply the Keras layer. This is probably something best prototyped for a number of vision tasks first (classification, detection, segmentation).
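For reference, one possible shape for that, assuming Keras' internal data adapter utilities (`keras.src.trainers.data_adapters` is a private API, so the exact names and signatures may differ between Keras versions):

```python
import numpy as np
import keras
from keras.src.trainers import data_adapters  # private API; subject to change

def iter_preprocessed(x, y, preprocessor, batch_size=32):
    # Let Keras normalize whatever the user passed (NumPy, tf.data, PyDataset, ...)
    # into a common adapter, then pull NumPy batches out of it.
    adapter = data_adapters.get_data_adapter(x=x, y=y, batch_size=batch_size)
    for batch in adapter.get_numpy_iterator():
        images, labels = batch[0], batch[1]
        yield preprocessor(images), labels

# Example usage with plain NumPy inputs and a stand-in preprocessing layer.
x = np.random.uniform(size=(64, 224, 224, 3)).astype("float32")
y = np.zeros((64,), dtype="int32")
resize = keras.layers.Resizing(96, 96)
for images, labels in iter_preprocessed(x, y, resize, batch_size=16):
    print(images.shape, labels.shape)
```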

mattdangerw avatar Mar 06 '25 01:03 mattdangerw

@mattdangerw I would love to contribute to this issue. I think that we can relax our reliance on tf.data for computer vision tasks by converting inputs to NumPy arrays and then using a simple Python generator for batching and preprocessing. This approach should allow us to efficiently support Torch and JAX backends. While converting to NumPy arrays does add a bit of overhead, it's generally efficient when working with array-like inputs, and using a generator helps mitigate any performance impact by handling data in manageable batches. Let me know how this sounds.
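Roughly what I have in mind (a sketch only; `batch_generator` and the resize layer here are illustrative, not proposed API names):

```python
import numpy as np
import keras

def batch_generator(x, y, preprocessor, batch_size=32):
    # Accept array-like inputs, convert once to NumPy, then yield
    # preprocessed batches without ever touching tf.data.
    x = np.asarray(x)
    y = np.asarray(y)
    for start in range(0, len(x), batch_size):
        batch_x = x[start:start + batch_size]
        batch_y = y[start:start + batch_size]
        yield np.asarray(preprocessor(batch_x)), batch_y

# Toy usage with a stand-in preprocessing layer.
x = np.random.uniform(size=(100, 224, 224, 3)).astype("float32")
y = np.random.randint(0, 10, size=(100,))
resize = keras.layers.Resizing(96, 96)
for batch_x, batch_y in batch_generator(x, y, resize, batch_size=25):
    print(batch_x.shape, batch_y.shape)
```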

arpitsinghgautam avatar Mar 22 '25 12:03 arpitsinghgautam

Hi @mattdangerw, to clarify a bit:

  • Referring, for instance, to ImageClassifier: the preprocessor seems to be the actual object used in preprocess_samples, and it is mainly a set of Keras layers. How about moving it into the model itself, so that preprocess_samples becomes an identity and PipelineModel can just behave like a normal model (see the sketch after this list)? I might be missing something obvious here, and maybe this is not always so straightforward (for text it might not even be possible).

  • At first I also thought about Keras' DataAdapter, which handles all the officially supported input combinations (NumPy iterators, torch Datasets, PyDataset, ...). But making that efficient is not straightforward; I think it would require changes in Keras itself, since somehow you would need to inject the operations currently run in preprocess_samples, perhaps by overriding a post_process hook that does not currently exist in the base data adapter class.
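To illustrate the first point, a minimal sketch of what folding the preprocessor into the model could look like (purely hypothetical; the real ImageClassifier constructor and preprocessor wiring differ, and the backbone here is a toy stand-in):

```python
import keras

# Hypothetical sketch: build a functional model whose first stage is the
# preprocessing layer(s), so the outer fit()/predict() loop no longer needs
# a tf.data-based preprocess_samples step.
def build_classifier_with_builtin_preprocessing(backbone, preprocessor, num_classes):
    inputs = keras.Input(shape=(None, None, 3))
    x = preprocessor(inputs)          # e.g. resizing / rescaling layers
    x = backbone(x)                   # feature extractor
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(num_classes)(x)
    return keras.Model(inputs, outputs)

# Toy usage with stand-in pieces.
preprocessor = keras.Sequential([
    keras.layers.Resizing(96, 96),
    keras.layers.Rescaling(1.0 / 255),
])
backbone = keras.Sequential([
    keras.layers.Conv2D(32, 3, activation="relu"),
    keras.layers.Conv2D(64, 3, activation="relu"),
])
model = build_classifier_with_builtin_preprocessing(backbone, preprocessor, num_classes=10)
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
```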

Let me know your thoughts, thank you.

edge7 avatar Mar 22 '25 14:03 edge7

I was trying to run some tests to understand more, particularly MobileNetImageClassifierTest. The preprocessing parts work fine with generic Keras ops that do not strictly rely on TF. At some point, since a task extends PipelineModel, we end up in PipelineModel, where the hard dependency lives (as you said originally in this thread).

Is the plan to support any type of input, like NumPy arrays, tf.data.Dataset, PyDataset, and so on?

If so, I am thinking about monkey patching or something similar to inject these transformations on the fly. Not great, but it should work with zero overhead.
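Just to make the idea concrete, this is the kind of patch I have in mind (a rough illustration only; `task`, `preprocessor`, and the wrapped `fit` are stand-ins, not actual KerasNLP internals):

```python
import numpy as np

def inject_preprocessing(task, preprocessor, batch_size=32):
    # Monkey patch fit() on this instance so inputs are preprocessed in a
    # plain Python generator before the stock training loop ever sees them.
    original_fit = task.fit

    def patched_fit(x, y, **kwargs):
        def generator():
            while True:  # loop forever so multi-epoch training keeps getting data
                for start in range(0, len(x), batch_size):
                    batch_x = np.asarray(preprocessor(x[start:start + batch_size]))
                    yield batch_x, y[start:start + batch_size]

        steps = int(np.ceil(len(x) / batch_size))
        return original_fit(generator(), steps_per_epoch=steps, **kwargs)

    task.fit = patched_fit
    return task
```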

edge7 avatar Mar 24 '25 16:03 edge7

Hi @mattdangerw, what are the next steps here?

edge7 avatar Mar 29 '25 19:03 edge7