flytekit
Resolve map task issues for batchable lists
TL;DR
Before this PR, a batchable list (one handled by FlytePickleTransformer) was batched into a single item by default and stored in one pickle file, so the resulting literal list contained only one literal. However, the current map task implementation only runs successfully when the length of the literal list matches the length of the original list.
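The length mismatch can be illustrated with a minimal pure-Python sketch. It does not use flytekit: `batch_to_literals` is a hypothetical helper, and each `pickle.dumps` blob stands in for one pickle literal (one pickle file).

```python
import pickle
from typing import Any, List


def batch_to_literals(values: List[Any], batch_size: int) -> List[bytes]:
    """Hypothetical stand-in for the batching step: each chunk of
    `batch_size` elements is pickled into a single blob."""
    return [
        pickle.dumps(values[i : i + batch_size])
        for i in range(0, len(values), batch_size)
    ]


data = [1, 2, 3, 4, 5]

# BatchSize(4): two blobs, holding [1, 2, 3, 4] and [5]
literals = batch_to_literals(data, batch_size=4)
print(len(literals), len(data))  # 2 5

# Default behaviour (batch size = list length): a single blob
print(len(batch_to_literals(data, batch_size=len(data))))  # 1
```

A map task that expects one literal per element would see 2 (or 1) literals instead of 5.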
Because this part of the map task logic lives in flyteplugins, and users upgrade propeller infrequently, it is more practical to make this change in flytekit.
The proposed solution appends placeholders (none literals) to the literal list until its length equals that of the original list.
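The padding step can be sketched in plain Python, assuming `None` stands in for a Flyte none literal and pickled `bytes` for a pickle literal. `pad_literals` is a hypothetical helper for illustration, not flytekit's actual code.

```python
import pickle
from typing import List, Optional


def pad_literals(literals: List[bytes], original_length: int) -> List[Optional[bytes]]:
    """Append None placeholders (standing in for none literals) until the
    literal list is as long as the original Python list."""
    padded: List[Optional[bytes]] = list(literals)
    padded.extend([None] * (original_length - len(literals)))
    return padded


data = [1, 2, 3, 4, 5]
# Two pickled batches, as produced with BatchSize(4)
batches = [pickle.dumps(data[0:4]), pickle.dumps(data[4:5])]
padded = pad_literals(batches, len(data))
print(len(padded))                      # 5
print([lit is None for lit in padded])  # [False, False, True, True, True]
```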
When converting back to a Python value, conversion can stop at the first literal of the 'none' type. Because the list is batchable, every genuine literal must be a pickle literal and never a none literal, so any none literal is padding.
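A minimal sketch of the decoding rule, under the same assumptions (pickled `bytes` for pickle literals, `None` for none literals). `literals_to_values` is hypothetical; it only shows why stopping at the first placeholder is safe. Note that a genuine `None` element lives inside a pickled batch, so it is never confused with padding.

```python
import pickle
from typing import Any, List, Optional


def literals_to_values(literals: List[Optional[bytes]]) -> List[Any]:
    """Unpickle each batch and flatten the results, stopping at the first
    None placeholder: in a batchable list every real literal is a pickle
    blob, so a None can only be padding."""
    values: List[Any] = []
    for lit in literals:
        if lit is None:
            break
        values.extend(pickle.loads(lit))
    return values


# Padded literal list for [1, 2, 3, 4, 5] with a batch size of 4
padded = [pickle.dumps([1, 2, 3, 4]), pickle.dumps([5]), None, None, None]
print(literals_to_values(padded))  # [1, 2, 3, 4, 5]
```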
The batch mechanism is worth keeping because it makes transforming large, batchable lists much faster: instead of uploading many small pickle files to S3 (one per element), it uploads a single large file. Converting elements asynchronously also saves time, but it is still considerably slower than batching.
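The saving comes down to the number of uploads: splitting N elements into batches of size B writes ceil(N / B) pickle files. The helper below is a hypothetical illustration of that arithmetic, using the 10,000-element list from the second check.

```python
def pickle_files_uploaded(num_elements: int, batch_size: int) -> int:
    """Number of pickle blobs written when `num_elements` items are split into
    batches of `batch_size` (the last batch may be smaller). Integer ceiling
    division avoids floating-point rounding."""
    return (num_elements + batch_size - 1) // batch_size


n = 10_000  # same size as the list returned by fetch_data() in Checks
print(pickle_files_uploaded(n, batch_size=1))  # 10000: one file per element
print(pickle_files_uploaded(n, batch_size=n))  # 1: default, whole list in one file
```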
Checks
The map task ran successfully.
```python
import typing

from flytekit import map_task, task, workflow
from flytekit.types.pickle.pickle import BatchSize


@task
def generate_list() -> typing.Annotated[typing.List[typing.Any], BatchSize(4)]:
    # BatchSize(4): elements are pickled in batches of 4, i.e. [1, 2, 3, 4] and [5]
    return [1, 2, 3, 4, 5]


@task
def my_mappable_task(a: typing.Any) -> typing.Optional[str]:
    print(a)
    return str(a)


@workflow
def my_wf() -> typing.List[typing.Optional[str]]:
    data = generate_list()
    # One mapped task instance per element of the original list
    return map_task(my_mappable_task, concurrency=10)(a=data)


if __name__ == "__main__":
    my_wf()
```
Large, batchable lists (here, a 10,000-element List[Any]) are transformed quickly.
```python
from typing import Any, List

from flytekit import Resources, task, workflow


@task(
    limits=Resources(mem="4Gi", cpu="1"),
)
def fetch_data() -> List[Any]:
    # No BatchSize annotation, so the whole list is pickled into a single file
    return [{"a": {0: "foo"}}] * 10000


@task(
    limits=Resources(mem="4Gi", cpu="1"),
)
def print_length(x: List[Any]):
    print(len(x))


@workflow
def my_wf():
    x = fetch_data()
    print_length(x=x)


if __name__ == "__main__":
    my_wf()
```