tfx-addons

TFX + PyTorch Example

Open • hanneshapke opened this issue 2 years ago • 10 comments

There are a few TFX examples showing how to train scikit-learn or JAX models, but I haven't seen an example pipeline for PyTorch.

The pipeline could use a well-known dataset (e.g. MNIST), ingest the data via CsvExampleGen, run the standard statistics and schema steps, perform a pseudo-transformation (a pass-through of the values) with the new PandasTransform component from tfx-addons, add a custom run_fn function for PyTorch, and then add a TFMA example.

Any thoughts?

hanneshapke, Aug 02 '22
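A minimal sketch of the proposed wiring, assuming the TFX 1.x `tfx.v1` API. The function name `create_pipeline`, its parameters, and the step counts are illustrative, not from the proposal. The PandasTransform and Evaluator steps are left as comments because this sketch does not assume their exact configuration.

```python
def create_pipeline(pipeline_name, pipeline_root, data_root, module_file,
                    metadata_path):
    """Build a minimal TFX pipeline that trains a PyTorch model (sketch)."""
    # Imported inside the function so this file can be imported without TFX.
    from tfx import v1 as tfx

    # Ingest CSV files (for example, an MNIST export) as tf.Example records.
    example_gen = tfx.components.CsvExampleGen(input_base=data_root)

    # Standard statistics and schema-inference steps.
    statistics_gen = tfx.components.StatisticsGen(
        examples=example_gen.outputs["examples"])
    schema_gen = tfx.components.SchemaGen(
        statistics=statistics_gen.outputs["statistics"])

    # The pass-through PandasTransform step from tfx-addons would sit here;
    # its configuration is not assumed in this sketch.

    # The Trainer loads `module_file` and calls its `run_fn(fn_args)`,
    # where the PyTorch training loop and model export would live.
    trainer = tfx.components.Trainer(
        module_file=module_file,
        examples=example_gen.outputs["examples"],
        schema=schema_gen.outputs["schema"],
        train_args=tfx.proto.TrainArgs(num_steps=100),
        eval_args=tfx.proto.EvalArgs(num_steps=10))

    # An Evaluator step is omitted: for a non-SavedModel PyTorch model it
    # would need a custom TFMA extractor.
    return tfx.dsl.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        metadata_connection_config=(
            tfx.orchestration.metadata.sqlite_metadata_connection_config(
                metadata_path)),
        components=[example_gen, statistics_gen, schema_gen, trainer])
```

Running it locally would look something like `tfx.orchestration.LocalDagRunner().run(create_pipeline(...))`. Inside `run_fn`, the trained model would be written to `fn_args.serving_model_dir` so that downstream components can find it.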

Proposal for the TFX Addons Example: https://github.com/tensorflow/tfx-addons/pull/157

hanneshapke, Aug 03 '22

Yes, please. If possible, let's demonstrate it with a Hugging Face model using the PyTorch (PT) backend.

sayakpaul, Aug 23 '22

One of the things that we will need for this is an ONNX extractor for Evaluator. Maybe we should break that out as a separate project?

rcrowe-google, Mar 10 '23

Could you elaborate on this a bit more?

sayakpaul, Mar 11 '23

> > One of the things that we will need for this is an ONNX extractor for Evaluator. Maybe we should break that out as a separate project?
>
> Could you elaborate on this a bit more?

My understanding is that for PyTorch developers ONNX is a normal format for saving trained models, while TF's SavedModel format introduces friction. For non-SavedModel models, the Evaluator needs an Extractor to generate the predictions it measures; see, for example, the existing ones for scikit-learn and XGBoost.

rcrowe-google, Mar 11 '23
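In TFMA, an Extractor is a stage that adds entries to per-example "extracts" (dicts). The framework-free sketch below shows the core job of a prediction extractor for a non-SavedModel model: run a user-supplied `predict_fn` over batches of features and attach the outputs. The key names, `predict_extracts`, and `predict_fn` are hypothetical. A real TFMA extractor would wrap similar logic in an Apache Beam `PTransform` registered through TFMA's `Extractor` type.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

# Hypothetical keys: they mimic TFMA's per-example "extracts" dicts but are
# plain strings here, not constants imported from TFMA.
FEATURES_KEY = "features"
PREDICTIONS_KEY = "predictions"

Extract = Dict[str, Any]
PredictFn = Callable[[List[Any]], List[Any]]


def predict_extracts(
    extracts: Iterable[Extract],
    predict_fn: PredictFn,
    batch_size: int = 32,
) -> Iterator[Extract]:
    """Yield copies of `extracts` with model predictions attached."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")

    def run_batch(batch: List[Extract]) -> Iterator[Extract]:
        predictions = predict_fn([e[FEATURES_KEY] for e in batch])
        if len(predictions) != len(batch):
            raise ValueError("predict_fn must return one prediction per input")
        for extract, prediction in zip(batch, predictions):
            out = dict(extract)  # shallow copy; the input dict is left untouched
            out[PREDICTIONS_KEY] = prediction
            yield out

    pending: List[Extract] = []
    for extract in extracts:
        pending.append(extract)
        if len(pending) == batch_size:
            yield from run_batch(pending)
            pending = []
    if pending:  # flush the final partial batch
        yield from run_batch(pending)
```

Batching is the main design choice here. An ONNX Runtime session or a PyTorch module is much cheaper per example when it is called on a whole batch than when it is called once per example.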

ONNX is definitely used, but I am not sure it is the normal format, as you suggested. This document gives a good rundown of the serialization semantics in PyTorch: https://pytorch.org/docs/stable/notes/serialization.html

ONNX is definitely popular there (PyTorch has a built-in ONNX exporter, too). What I gather is that we make ONNX the serialization format for the PyTorch models so that they work in a TFX pipeline. Is that right?

sayakpaul, Mar 11 '23
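The sketch below makes the two options in this comment concrete, assuming `torch` and `onnxruntime` are installed. The first function saves and reloads a `state_dict`, one of the approaches covered in the linked serialization notes. The other two functions export a model with `torch.onnx.export` and run the resulting file with ONNX Runtime. The function names, the `input`/`output` tensor names, and the dynamic batch axis are illustrative choices.

```python
import io


def save_and_reload_state_dict(model_factory):
    """Round-trip a model's parameters through torch.save / torch.load.

    A state_dict holds tensors only, so the model class (here, whatever
    `model_factory` builds) must be importable wherever it is reloaded.
    """
    import torch  # imported lazily so this file loads without PyTorch

    model = model_factory()
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)
    restored = model_factory()
    restored.load_state_dict(torch.load(buffer, weights_only=True))
    return model, restored


def export_to_onnx(model, example_input, path):
    """Export a PyTorch module to an ONNX file with a dynamic batch axis."""
    import torch

    model.eval()  # put dropout/batch-norm layers into inference mode
    torch.onnx.export(
        model,
        (example_input,),
        path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )


def onnx_predict(path, batch):
    """Run an exported single-output model with ONNX Runtime on CPU."""
    import onnxruntime as ort

    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    (output,) = session.run(None, {input_name: batch})
    return output
```

This contrast matters for the Evaluator. Reloading a `state_dict` needs the model class from the training code, whereas an ONNX file carries the graph itself and only needs a runtime to execute.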

> ... we make ONNX the serialization format for the PyTorch models

My thought is more that ONNX would be one of several serialization formats, which to me suggests that breaking it out as a separate project might make sense. We could also write Extractors for TensorRT, TorchScript, or whatever makes sense (and here I'm displaying my ignorance about what makes sense) and let users choose the one they need.

rcrowe-google, Mar 11 '23

Got it. Yeah, I agree with your thoughts now.

Moreover, it might make even more sense because users might want to choose an Extractor that matches their deployment infrastructure. For example, ONNX might be better for CPU-based deployment, while TensorRT would be better suited to a GPU-based runtime (although ONNX Runtime can also use TensorRT as a backend).

sayakpaul, Mar 11 '23
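ONNX Runtime already exposes this CPU/GPU choice through execution providers, so an ONNX-based extractor could select providers from the deployment target. In the sketch below, the provider names are ONNX Runtime's real identifiers. The `choose_providers` helper and the preference orders are illustrative assumptions.

```python
from typing import List, Sequence

# ONNX Runtime execution-provider names, in order of preference.
GPU_PREFERENCE = [
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
CPU_PREFERENCE = ["CPUExecutionProvider"]


def choose_providers(available: Sequence[str], use_gpu: bool) -> List[str]:
    """Return the preferred providers that are actually available, best first."""
    preference = GPU_PREFERENCE if use_gpu else CPU_PREFERENCE
    chosen = [name for name in preference if name in available]
    if not chosen:
        raise RuntimeError(
            f"none of {preference} is available; got {list(available)}")
    return chosen
```

With ONNX Runtime installed, the result plugs directly into a session: `ort.InferenceSession(path, providers=choose_providers(ort.get_available_providers(), use_gpu=True))`. Keeping the CPU provider last in the GPU list lets ONNX Runtime fall back to it for operators that the GPU providers do not support.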

I think Wihan wrote a custom TFMA extractor for PyTorch. We had everything done up to the Trainer when we shared the notebook with Wihan. The last time we talked, he was in the process of cleaning up his implementation. He said it worked end-to-end.

hanneshapke, Mar 11 '23

> I think Wihan wrote a custom TFMA extractor for PyTorch. We had everything done up to the Trainer when we shared the notebook with Wihan. The last time we talked, he was in the process of cleaning up his implementation. He said it worked end-to-end.

@wihanbooyse - That would be great! It might make sense to refactor the example to break out the extractor separately, and follow that up with some more extractors for other formats.

rcrowe-google, Mar 11 '23