Stop requiring users to import `dataclasses_json` or `DataClassJSONMixin` for dataclass
Tracking issue
https://github.com/flyteorg/flyte/issues/4486
Why are the changes needed?
For a better user experience.
What changes were proposed in this pull request?
- use `mashumaro>=3.11`, so that we can use `JSONEncoder` and `JSONDecoder`
- ~~change `python_val.to_json()` to `JSONEncoder(python_type).encode(python_val)`~~
- ~~change `expected_python_type.from_json(json_str)` to `JSONDecoder(expected_python_type).decode(json_str)`~~
- ~~change `return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))` to `return dataclasses.make_dataclass(schema_name, attribute_list)`, since we don't need the `to_json` and `from_json` methods anymore~~
- add tests
- fix mypy errors
- change type annotations
- ~~remove `flytekit-dolt` from the CI tests, since it doesn't work now and needs to be reimplemented in a new version~~
- use `JSONEncoder` and `JSONDecoder` to convert a `dataclass` to a JSON string and back when the user didn't use `dataclasses_json` or `DataClassJSONMixin`
- add an encoder registry and a decoder registry to cache `JSONEncoder` and `JSONDecoder` instances when using `List[dataclass]`
- add a benchmark test for a real-world scenario: a dynamic workflow that returns `List[dataclass]`
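The core idea of the change can be sketched as follows. This is a minimal stand-in using only the standard library (the actual PR uses mashumaro's `JSONEncoder`/`JSONDecoder`); the `SketchJSONEncoder`/`SketchJSONDecoder` class names are illustrative, not flytekit's or mashumaro's real API. The point is that the codec is built from the dataclass *type*, so the user's dataclass needs no mixin or decorator:

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Type


# Illustrative stand-in for mashumaro's JSONEncoder: constructed from the
# dataclass type, then used to serialize instances of that type.
class SketchJSONEncoder:
    def __init__(self, python_type: Type):
        self._type = python_type

    def encode(self, python_val: Any) -> str:
        return json.dumps(asdict(python_val))


# Illustrative stand-in for mashumaro's JSONDecoder: rebuilds an instance
# of the dataclass type from a JSON string.
class SketchJSONDecoder:
    def __init__(self, python_type: Type):
        self._type = python_type

    def decode(self, json_str: str) -> Any:
        return self._type(**json.loads(json_str))


@dataclass
class Datum:
    x: int
    y: str


# Round trip: no dataclasses_json or DataClassJSONMixin required.
json_str = SketchJSONEncoder(Datum).encode(Datum(x=1, y="1"))
restored = SketchJSONDecoder(Datum).decode(json_str)
print(restored)  # Datum(x=1, y='1')
```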
How was this patch tested?
- unit tests
- locally and remotely with only the `dataclass` decorator
- locally and remotely with a dataclass that inherits from `DataClassJSONMixin` (for backward compatibility)
- locally and remotely with the `dataclass_json` decorator (for backward compatibility)

Note: you can use the image futureoutlier/dataclass:0321 to test it.

Setup process

```shell
python dataclass_example.py
pyflyte run --remote --image localhost:30000/dataclass:0951 dataclass_example.py dataclass_wf --x 10 --y 20
```
```python
import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple

import pandas as pd

from flytekit import task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

# from mashumaro.mixins.json import DataClassJSONMixin


@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, int]
    w: Optional[List[int]] = None


@task
def stringify(s: int) -> Datum:
    """
    A dataclass return will be treated as a single complex JSON return.
    """
    return Datum(x=s, y=str(s), z={s: s}, w=[s, s, s, s])


@task
def add(x: Datum, y: Datum) -> Datum:
    """
    Flytekit automatically converts the provided JSON into a data class.
    If the structures don't match, it triggers a runtime failure.
    """
    x.z.update(y.z)
    return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z, w=x.w + y.w)


@dataclass
class FlyteTypes:
    dataframe: StructuredDataset
    file: FlyteFile
    directory: FlyteDirectory


@task
def upload_data() -> FlyteTypes:
    """
    Flytekit will upload FlyteFile, FlyteDirectory and StructuredDataset
    to the blob store, such as GCP or S3.
    """
    # 1. StructuredDataset
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    # 2. FlyteDirectory
    temp_dir = tempfile.mkdtemp(prefix="flyte-")
    df.to_parquet(temp_dir + "/df.parquet")
    # 3. FlyteFile
    file_path = tempfile.NamedTemporaryFile(delete=False)
    file_path.write(b"Hello, World!")
    file_path.close()
    fs = FlyteTypes(
        dataframe=StructuredDataset(dataframe=df),
        file=FlyteFile(file_path.name),
        directory=FlyteDirectory(temp_dir),
    )
    return fs


@task
def download_data(res: FlyteTypes):
    assert pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}).equals(
        res.dataframe.open(pd.DataFrame).all()
    )
    with open(res.file, "r") as f:
        assert f.read() == "Hello, World!"
    assert os.listdir(res.directory) == ["df.parquet"]


@workflow
def dataclass_wf(x: int, y: int) -> Tuple[Datum, FlyteTypes]:
    o1 = add(x=stringify(s=x), y=stringify(s=y))
    o2 = upload_data()
    download_data(res=o2)
    return o1, o2


if __name__ == "__main__":
    print(dataclass_wf(x=10, y=20))
```
```dockerfile
FROM python:3.9-slim-buster
USER root
WORKDIR /root
ENV PYTHONPATH /root
RUN apt-get update && apt-get install build-essential -y
RUN apt-get install git -y
RUN pip install -U git+https://github.com/flyteorg/flytekit.git@30223e45c6b773cb25846f5031f92e4f1f783c33
RUN pip install pandas -U
```
Screenshots
- local execution (with only the `dataclass` decorator)
- local execution (with `DataClassJSONMixin`)
- remote execution (with only the `dataclass` decorator)
- remote execution (with `DataClassJSONMixin`)
- remote execution (with `dataclass_json`)
Check all the applicable boxes
- [ ] I updated the documentation accordingly.
- [x] All new and existing tests passed.
- [x] All commits are signed-off.
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 83.49%. Comparing base (55f0b19) to head (0a33f53).
```diff
@@            Coverage Diff             @@
##           master    #2279      +/-   ##
==========================================
+ Coverage   83.46%   83.49%   +0.03%
==========================================
  Files         324      324
  Lines       24754    24757       +3
  Branches     3521     3519       -2
==========================================
+ Hits        20662    20672      +10
+ Misses       3460     3455       -5
+ Partials      632      630       -2
```
cc @thomasjpfan Please take a look, thank you!
I think we can remove all `DataClassJSONMixin` classes that also have the `@dataclass` decorator.
Should we do this in this PR or create another?
> I think we can remove all `DataClassJSONMixin` classes that also have the `@dataclass` decorator. Should we do this in this PR or create another?
I think it will be fine to open another housekeeping PR for this, since we also need to update the relevant tests, and the change would be huge and time-consuming to review.
Test Performance
I used the code above to compare performance across three setups: this PR (only the `dataclass` decorator needed), mashumaro's `DataClassJSONMixin`, and `@dataclass_json`.
```python
def test_speed():
    import time

    start_time = time.time()
    for _ in range(1000):
        dataclass_wf(x=10, y=20)
    end_time = time.time()
    print(f"Time taken: {end_time - start_time}")
```
I think the main reason for this performance gap is that we need to create a `JSONEncoder` or a `JSONDecoder` to serialize and deserialize our dataclasses.
However, I still think this PR is worth merging. If we compare them over only one iteration, the gap is very small and won't pose a performance issue.
> I think the main reason for this performance gap is that we need to create a JsonEncoder or a JsonDecoder to serialize and deserialize our dataclasses.

I'm sure it is. Creating decoders and encoders is not a cheap operation. I would recommend using a registry (`dataclass_type -> decoder/encoder`) if this is an issue for you.
> > I think the main reason for this performance gap is that we need to create a JsonEncoder or a JsonDecoder to serialize and deserialize our dataclasses.
>
> I'm sure it is. Creating decoders and encoders is not a cheap operation. I would recommend using a registry (`dataclass_type -> decoder/encoder`) if this is an issue for you.
I love your idea, but in a real case we would need to register the dataclass encoder and decoder again in every pod task node.
Still, this solution should be better, meaning it is still faster.
Either way, I really think it is worth using `JSONEncoder` and `JSONDecoder`.
The main performance issue is not serialization; in other words, we can optimize other places.
@thomasjpfan, @Fatal1ty I used this example to test the speed of the workflow; it seems the performance is close : )
```python
import time
from typing import List

from flytekit import dynamic, task, workflow

# Datum is the dataclass defined in the example above.


@task
def create_dataclasses() -> List[Datum]:
    return [Datum(x=1, y="1", z={1: 1}, w=[1, 1, 1, 1])]


@task
def concat_dataclasses(x: List[Datum], y: List[Datum]) -> List[Datum]:
    return x + y


@dynamic
def dynamic_wf() -> List[Datum]:
    all_dataclasses = [Datum(x=1, y="1", z={1: 1}, w=[1, 1, 1, 1])]
    for _ in range(300):
        data = create_dataclasses()
        all_dataclasses = concat_dataclasses(x=all_dataclasses, y=data)
    return all_dataclasses


@workflow
def benchmark_workflow() -> List[Datum]:
    return dynamic_wf()


if __name__ == "__main__":
    start_time = time.time()
    benchmark_workflow()
    end_time = time.time()
    print(f"Time taken: {end_time - start_time}")
```
> > I think the main reason for this performance gap is that we need to create a JsonEncoder or a JsonDecoder to serialize and deserialize our dataclasses.
>
> I'm sure it is. Creating decoders and encoders is not a cheap operation. I would recommend using a registry (`dataclass_type -> decoder/encoder`) if this is an issue for you.

I've tested it on your advice, and it really reduces the time drastically.
Is this backwards compatible? Serialize with an old flytekit release and old user code, then deserialize with the new flytekit and new user code.
I think not requiring `dataclasses_json` or `DataClassJSONMixin` for many use cases is already a net improvement. LGTM
Added comments, thank you so much