Allow computer vision task models to run without tf.data
First step (of a few) toward slowly relaxing our reliance on tf.data for preprocessing.
Our text models are more heavily reliant on tf.data because of the tf-text dependency. Our image models do not have this constraint.
We could try to allow preprocessing to run without tf.data on the torch and jax backends. To do so, we would need to stop always converting to a tf.data.Dataset in our pipeline model helper here and find a way to still apply preprocessing to the iterator efficiently.
Contributions are welcome here, but this is a fairly abstract problem that would need some scouting out first. We could try to leverage Keras' DataAdapter here, though I'm not sure how best to iterate over the dataset and apply the Keras layer. This is probably something best prototyped for a number of vision tasks first (classification, detection, segmentation).
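To sketch the rough direction (illustrative only; `preprocessed_batches`, `preprocessor`, and `numpy_batches` below are placeholder names, not our actual pipeline code), the idea would be to pull batches from whatever iterator the data adapter gives us and run the preprocessing layer on each batch directly, with no tf.data.Dataset in between:

```python
def preprocessed_batches(batch_iterator, preprocessor):
    # `batch_iterator` yields (x, y) batches as backend-native arrays
    # (NumPy, torch, or jax); `preprocessor` is the task's Keras
    # preprocessing layer, applied eagerly to each batch.
    for x, y in batch_iterator:
        yield preprocessor(x), y


def numpy_batches(images, labels, batch_size=32):
    # Trivial batch iterator over in-memory arrays, just for prototyping.
    for start in range(0, len(images), batch_size):
        yield (
            images[start:start + batch_size],
            labels[start:start + batch_size],
        )
```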
@mattdangerw I would love to contribute to this issue. I think that we can relax our reliance on tf.data for computer vision tasks by converting inputs to NumPy arrays and then using a simple Python generator for batching and preprocessing. This approach should allow us to efficiently support Torch and JAX backends. While converting to NumPy arrays does add a bit of overhead, it's generally efficient when working with array-like inputs, and using a generator helps mitigate any performance impact by handling data in manageable batches. Let me know how this sounds.
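Very roughly, something along these lines (just a sketch; `preprocessor` and `generate_preprocessed` are illustrative names, not the actual KerasHub API):

```python
import numpy as np


def generate_preprocessed(x, y, preprocessor, batch_size=32):
    # Convert array-like inputs to NumPy once up front.
    x, y = np.asarray(x), np.asarray(y)
    # Batch and preprocess lazily with a plain Python generator, so the
    # torch and jax backends never touch a tf.data.Dataset.
    for start in range(0, len(x), batch_size):
        yield (
            preprocessor(x[start:start + batch_size]),
            y[start:start + batch_size],
        )
```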
Hi @mattdangerw, to clarify a bit:

- Referring, for instance, to `ImageClassifier`: the `preprocessor` looks like the actual object used in `preprocess_samples`, if `preprocess_samples` is mainly a set of Keras layers. How about moving that into the model itself, so that `preprocess_samples` can become an identity and `PipelineModel` can then just become a normal model (see the sketch below)? I might be missing something obvious here, and maybe this is not always so straightforward (for text it may not even be possible).
- At first, I also thought about the Keras `DataAdapter`, which handles all the officially supported input combinations (NumPy iterators, Torch `Dataset`, `PyDataset`, ...). But making it "efficient" is not straightforward; I think it would require changes in Keras itself, since you would somehow have to inject those operations (the ones currently run in `preprocess_samples`), maybe by overriding a `post_process` hook, which is not even a thing at the minute in the base data adapter class.
Let me know your thoughts, thank you.
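To make the first bullet more concrete, a minimal sketch (assuming the preprocessor is an ordinary Keras layer; `preprocessor` and `backbone` are placeholders, not the real KerasHub objects):

```python
import keras


class TaskWithInlinePreprocessing(keras.Model):
    """Sketch: preprocessing lives inside the model's forward pass."""

    def __init__(self, preprocessor, backbone, **kwargs):
        super().__init__(**kwargs)
        self.preprocessor = preprocessor
        self.backbone = backbone

    def call(self, inputs):
        # Preprocessing happens here, so the data pipeline (and hence the
        # PipelineModel tf.data conversion) no longer needs to apply it.
        return self.backbone(self.preprocessor(inputs))
```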
I was running some tests to understand this better, in particular MobileNetImageClassifierTest.
The pre-processing parts work fine with generic Keras ops that do not strictly rely on TF.
At some point, since a task extends PipelineModel, we end up in PipelineModel, where the hard dependency is (as you said originally in this thread).
Is the plan to support any type of input, like NumPy arrays, tf.data.Dataset, PyDataset, and so on?
If so, I am thinking about monkey patching or something similar to inject these transformations on the fly. Not great, but it should work with zero overhead.
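For instance, rather than patching Keras internals, the same injection could happen by overriding `fit()` in the task base class (just a sketch, assuming Keras 3 accepts plain Python generators in `fit()`; `preprocess_samples` below is only a stand-in for the task's preprocessing, not the actual PipelineModel method):

```python
import numpy as np
import keras


class PreprocessOnTheFlyTask(keras.Model):
    def preprocess_samples(self, x, y=None):
        # Identity here; a real task would apply its Keras preprocessing
        # layers to the batch.
        return x, y

    def fit(self, x=None, y=None, **kwargs):
        if isinstance(x, (np.ndarray, list, tuple)):
            # Array-like inputs: preprocess eagerly, then fit as usual.
            x, y = self.preprocess_samples(x, y)
            return super().fit(x, y=y, **kwargs)
        # Iterator inputs (generator, DataLoader, ...): inject the
        # preprocessing batch by batch, with no tf.data conversion.
        wrapped = (self.preprocess_samples(*batch) for batch in x)
        return super().fit(wrapped, **kwargs)
```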
Hi @mattdangerw, what are the next steps here?