flytekit
TypeTransformer for reading and writing from TensorFlowRecord format
TL;DR
This flytekit feature adds support for reading and writing the .tfrecord file format, using TensorFlow's tf.train.Example as a native type.
Type
- [ ] Bug Fix
- [x] Feature
- [ ] Plugin
Are all requirements met?
- [x] Code completed
- [ ] Smoke tested
- [x] Unit tests added
- [x] Code documentation added
- [ ] Any pending items have an associated Issue
Complete description
- Adds a TensorflowExampleTransformer type in flytekit/extras/tensorflow/records.py, which uses the [tf.train.Example](https://www.tensorflow.org/api_docs/python/tf/train/Example) message to serialize, write, and read tf.train.Example messages to and from .tfrecord files, following the examples in the TensorFlow docs: https://www.tensorflow.org/tutorials/load_data/tfrecord
- Adds tests for the serialization and deserialization steps of the transformer in tests/flytekit/unit/extras/tensorflow/test_transformations.py
- Adds a test for an example workflow using the tf.train.Example message.
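For context, tf.train.Example is a protobuf message holding a map of named features; a minimal construction and round-trip (a sketch, assuming TensorFlow 2.x; the feature names are made up) looks like:

```python
import tensorflow as tf

# Build a tf.train.Example with one float feature and one int64 feature.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "x": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0])),
            "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
        }
    )
)

# The serialized bytes are what get written to a .tfrecord file.
serialized = example.SerializeToString()
roundtrip = tf.train.Example.FromString(serialized)
```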
Tracking Issue
https://github.com/flyteorg/flyte/issues/2571
I haven't added this as a plugin, since the original issue description was to add this feature in a format similar to the PyTorch transformer type.
The unit test failures seem to be caused by tensorflow not being included:

E ModuleNotFoundError: No module named 'tensorflow'

You should be able to add this to dev-requirements.in.
I'm excited to see more tensorflow support being contributed!
@dennisobrien thanks, I pushed the changes now. I've also created a PR https://github.com/flyteorg/flytekit/pull/1242 for Keras model support!
Codecov Report
Merging #1240 (23c8bea) into master (f616cd4) will increase coverage by 0.24%. The diff coverage is 73.09%.
@@ Coverage Diff @@
## master #1240 +/- ##
==========================================
+ Coverage 68.83% 69.08% +0.24%
==========================================
Files 291 295 +4
Lines 26683 26922 +239
Branches 2140 2531 +391
==========================================
+ Hits 18368 18598 +230
- Misses 7817 7829 +12
+ Partials 498 495 -3
Impacted Files | Coverage Δ |
---|---|
flytekit/extras/tensorflow/__init__.py | 0.00% <0.00%> (ø) |
flytekit/types/directory/__init__.py | 0.00% <0.00%> (ø) |
flytekit/types/file/__init__.py | 17.07% <0.00%> (-0.88%) :arrow_down: |
flytekit/extras/tensorflow/record.py | 47.12% <47.12%> (ø) |
...tekit/unit/extras/tensorflow/record/test_record.py | 100.00% <100.00%> (ø) |
...t/extras/tensorflow/record/test_transformations.py | 100.00% <100.00%> (ø) |
flytekit/interfaces/random.py | 20.00% <0.00%> (-5.00%) :arrow_down: |
flytekit/configuration/internal.py | 16.43% <0.00%> (-2.03%) :arrow_down: |
flytekit/types/directory/types.py | 55.73% <0.00%> (-0.47%) :arrow_down: |
flytekit/types/file/file.py | 60.00% <0.00%> (-0.42%) :arrow_down: |
... and 9 more |
@pingsutw pushed requested changes
Writing feedback here for posterity.
Draft Proposal
- Create a TFRecordFile type that extends FlyteFile to include an additional record type, FlyteFile["tfrecord"], for serializing/deserializing tfrecords, which handles tf.train.Example task outputs automatically.
- Extend FlyteDirectory to TFRecordsDirectory, which automatically handles List[tf.train.Example] outputs by serializing them as TFRecords and storing them as a multi-part blob.
Why not just a type transformer for tf.train.Example?
Because when we create integrations with other frameworks/libraries, we should facilitate serialization to recommended, stable file formats and deserialize to Python objects that:
- are most useful to the users of the framework (in this case TensorFlow)
- conform to practical usage patterns.
Since tf.train.Example is a protobuf message that can't actually be used for model training and needs to be converted into a TFRecord (which is subsequently loaded into a tf.data.Dataset by the user), supporting tf.train.Example as a type transformer may lead to confusion, whereas a TFRecordFile that automatically handles tf.train.Example outputs (and of course can handle filepaths like regular FlyteFile types) is clearer in intent:
@task
def produce_record(...) -> TFRecordFile:
return tf.train.Example(...)
Furthermore, the key assumption in this proposal is that not many people actually output a single tf.train.Example in a task, but rather a collection of Examples.
@task
def produce_records(...) -> TFRecordsDirectory:
return [tf.train.Example(...) for _ in range(100)]
Here, TFRecordsDirectory would automatically serialize the list of Examples into a FlyteDirectory of TFRecords, which can then be passed to a downstream task:
@task
def consume_records(tf_records: TFRecordsDirectory):
    # build full paths; os.listdir alone returns bare filenames
    files = [os.path.join(tf_records, f) for f in os.listdir(tf_records)]
    return tf.data.TFRecordDataset(files, ...)
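For reference, the serialization that TFRecordsDirectory would automate can be sketched directly against the TensorFlow APIs (a sketch, assuming TensorFlow 2.x; the file and feature names are made up):

```python
import os
import tempfile

import tensorflow as tf

def make_example(value: float) -> tf.train.Example:
    # One float feature per record, purely for illustration.
    return tf.train.Example(
        features=tf.train.Features(
            feature={"x": tf.train.Feature(float_list=tf.train.FloatList(value=[value]))}
        )
    )

# "Producer" side: serialize Examples into a .tfrecord file in a directory.
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "part-0.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i in range(3):
        writer.write(make_example(float(i)).SerializeToString())

# "Consumer" side: build a dataset from the full file paths.
files = [os.path.join(out_dir, f) for f in os.listdir(out_dir)]
dataset = tf.data.TFRecordDataset(files)
n = sum(1 for _ in dataset)  # number of serialized records read back
```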
Questions
- Do we need a type to handle a single tf.train.Example? I'd say no 🙃 but happy to discuss more.
- Do we actually need TFRecordFile to serialize single records as outputs to tasks?
- Do we need a type transformer for tf.data.Dataset? How much value would something like this provide?
@task
def produce_records(...) -> TFRecordsDirectory:
return [tf.train.Example(...) for _ in range(100)]
@task
def consume_records(
dataset: Annotated[
tf.data.TFRecordDataset,
# configure kwargs to the constructor
# https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
TFRecordDatasetConfig(...)
]
):
... # use the dataset directly
@cosmicBboy, thanks for writing this up! I like the idea behind TFRecordFile and TFRecordsDirectory. The directory format might be more useful, but I think we also need to support storing a single tf.train.Example or tf.data.Dataset.
Concerning your questions:
- I agree; we don't need a tf.train.Example TypeTransformer.
- I think so.
- Am I right in assuming that dataset here corresponds to TFRecordFile or TFRecordsDirectory? If so, besides kwargs, we might also need to let users call methods, e.g., see how get_dataset fetches the data from a TFRecordDataset. But I don't think it's possible to streamline this into a type, so a better alternative would be to enable users to provide kwargs and let them apply additional methods or parsers if needed within a task; I think this could facilitate extraction of the data from a TFRecordDataset to some extent.

As for the code structure, will this go into the flytekit/extras directory?
But I don't think it's possible to streamline this into a type; so a better alternative will be to enable users to provide kwargs and let them apply additional methods or parsers if needed within a task
Right, I'm thinking for the tf.data.TFRecordDataset annotated type, we'd just handle the initialization of the object tf.data.TFRecordDataset(filenames, **kwargs) and then pass that into the task; the user is responsible for other transformations in the function body:
@task
def consume_records(
dataset: Annotated[
tf.data.TFRecordDataset,
# configure kwargs to the constructor
# https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
TFRecordDatasetConfig(...)
]
):
dataset = (
dataset
.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
.shuffle(batch_size * 10)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
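As a sketch of the annotation itself, TFRecordDatasetConfig could be a plain dataclass whose fields mirror the tf.data.TFRecordDataset constructor kwargs (the class and its kwargs() helper are hypothetical here, not an existing flytekit API):

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class TFRecordDatasetConfig:
    # Fields mirror tf.data.TFRecordDataset's constructor arguments.
    compression_type: Optional[str] = None    # "", "ZLIB", or "GZIP"
    buffer_size: Optional[int] = None         # bytes in the read buffer
    num_parallel_reads: Optional[int] = None  # files read in parallel

    def kwargs(self) -> dict:
        # Drop unset fields so TFRecordDataset's own defaults apply.
        return {k: v for k, v in asdict(self).items() if v is not None}
```

The transformer would then call tf.data.TFRecordDataset(filenames, **config.kwargs()) when materializing the input.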
What do you think? If this looks good I can update the proposal
As for the code structure, will this go into flytekit/extras directory?
Yep! As long as we follow the same conventions as the pytorch extra, I think we should make this part of the main flytekit API.
@cosmicBboy looks good to me! @ryankarlos please read through the comments.
@cosmicBboy @samhita-alla I'm getting this error when running test_native.py, even though I have added the tensorflow.python.data.ops.readers.TFRecordDatasetV2 type to to_python_value in TensorflowRecordsTransformer. Any ideas?
==================================== ERRORS ====================================
_ ERROR collecting tests/flytekit/unit/extras/tensorflow/records/test_native.py _
tests/flytekit/unit/extras/tensorflow/records/test_native.py:45: in <module>
def wf():
flytekit/core/workflow.py:739: in workflow
return wrapper(_workflow_function)
flytekit/core/workflow.py:734: in wrapper
workflow_instance.compile()
flytekit/core/workflow.py:614: in compile
workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs)
flytekit/exceptions/scopes.py:198: in user_entry_point
return wrapped(*args, **kwargs)
tests/flytekit/unit/extras/tensorflow/records/test_native.py:48: in wf
consume(dataset=tf.data.TFRecordDataset(filenames=[generate_tf_record_file()], name="t1"))
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/data/ops/readers.py:458: in __init__
filenames = _create_or_validate_filenames_dataset(filenames, name=name)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/data/ops/readers.py:66: in _create_or_validate_filenames_dataset
filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/profiler/trace.py:183: in wrapped
return func(*args, **kwargs)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:1640: in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py:343: in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py:267: in constant
return _constant_impl(value, dtype, shape, name, verify_shape=False,
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py:279: in _constant_impl
return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py:304: in _constant_eager_impl
t = convert_to_eager_tensor(value, ctx, dtype)
../../opt/miniconda3/envs/flyte/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py:102: in convert_to_eager_tensor
return ops.EagerTensor(value, ctx.device_name, dtype)
E ValueError: Attempt to convert a value (Promise(node:n1.o0)) with an unsupported type (<class 'flytekit.core.promise.Promise'>) to a Tensor.
------------------------------- Captured stderr --------------------------------
{"asctime": "2022-11-04 21:42:05,906", "name": "flytekit", "levelname": "WARNING", "message":
"Unsupported Type <class 'tensorflow.python.data.ops.readers.TFRecordDatasetV2'> found, Flyte will default to use PickleFile as the transport.
Pickle can only be used to send objects between the exact same version of Python, and we strongly recommend to use python type that flyte support."}
2022-11-04 21:42:05.914428: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Within the workflow function, everything is a Promise. Can you try doing the work within a task?
@workflow
def wf():
    # generate_tf_record_file and get_tf_record_dataset_from_file
    # would each be @task-decorated
    file = generate_tf_record_file()
    tf_record_dataset = get_tf_record_dataset_from_file(file)
    consume(dataset=tf_record_dataset)
Test failures on CI are unrelated to tests in this PR

Can you import Annotated from typing_extensions? That should fix the failures.
Amazing work, @ryankarlos! A few more comments. Sorry about incrementally reviewing the PR. :/
Thank you! No, that's fine, you have spotted a lot of my errors, which is good!