[BUG] Fix compilation error for partial tasks with lists in map_task
Tracking issue
Reference Issue
Why are the changes needed?
Currently, Flyte does not support partial tasks with lists as inputs, resulting in errors when using map_task with such inputs. When lists are passed as default arguments in workflows, Flyte works correctly locally but throws an ambiguous error remotely, indicating that input arrays have different lengths. This PR addresses the issue by raising a clear error during the compilation phase when partial tasks are used with lists.
What changes were proposed in this pull request?
• Modified the behavior of Flyte to raise a ValueError at compilation time when partial tasks are used with lists as inputs.
• The error message will indicate that the input arrays have different lengths, providing more clarity and preventing ambiguous errors at runtime.
• Updated the workflow behavior to handle default arguments in map_task appropriately.
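To illustrate the kind of compile-time check described above, here is a minimal plain-Python sketch. It is not the flytekit implementation: `validate_map_inputs` is a hypothetical helper, and the body of `mult_sum` is assumed for illustration only. It collects the list-valued inputs bound through `functools.partial` together with the mapped inputs and raises a ValueError when their lengths disagree.

```python
import functools
from typing import Any, Callable


def validate_map_inputs(target: Callable[..., Any], **mapped_inputs: Any) -> None:
    """Hypothetical helper (not a flytekit API): reject list inputs of unequal length."""
    # Keyword arguments bound via functools.partial live in `.keywords`.
    bound = target.keywords if isinstance(target, functools.partial) else {}
    # Record the length of every list-valued input, bound or mapped.
    lengths = {
        name: len(value)
        for name, value in {**bound, **mapped_inputs}.items()
        if isinstance(value, list)
    }
    if len(set(lengths.values())) > 1:
        raise ValueError(f"input arrays have different lengths: {lengths}")


def mult_sum(my_val: int, my_list: list[int]) -> int:
    # Assumed body; the real task is not shown in this PR.
    return my_val * sum(my_list)


partial_task = functools.partial(mult_sum, my_list=[1, 2, 3])

# Equal lengths pass silently.
validate_map_inputs(partial_task, my_val=[1, 2, 3])

# Unequal lengths fail before anything is executed.
try:
    validate_map_inputs(partial_task, my_val=[1, 2])
except ValueError as err:
    print(err)
```

Running the check before any task executes is what moves the failure from remote runtime to compilation.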
How was this patch tested?
The patch was tested with a test script covering several cases of map_task used with partial tasks and lists. The tests include:
• Mismatched input lengths: Ensures a ValueError is raised when input arrays have different lengths.
• Matching input lengths: Verifies that the task executes successfully when input arrays have the same length.
• Default arguments in workflows: Confirms that default arguments in workflows work as expected without errors.
The tests passed successfully, confirming that the issue was resolved.
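The three cases can be sketched with `unittest` roughly as follows. This is not the actual test script: `map_over` is a hypothetical stand-in for a multi-input map (so the example runs without a Flyte backend), and `add` and `wf` are illustrative names.

```python
import unittest


def map_over(fn, **list_inputs):
    """Hypothetical stand-in for a multi-input map; not flytekit's map_task."""
    lengths = {name: len(values) for name, values in list_inputs.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"input arrays have different lengths: {lengths}")
    names = list(list_inputs)
    # Call fn once per row, pairing each input name with its i-th element.
    return [fn(**dict(zip(names, row))) for row in zip(*list_inputs.values())]


def add(a: int, b: int) -> int:
    return a + b


def wf(a: tuple = (1, 2, 3)) -> list:
    # Mimics a workflow whose list input has a default value.
    return map_over(add, a=list(a), b=[10, 20, 30])


class MapInputTests(unittest.TestCase):
    def test_mismatched_lengths_raise(self):
        with self.assertRaises(ValueError):
            map_over(add, a=[1, 2, 3], b=[1, 2])

    def test_matching_lengths_succeed(self):
        self.assertEqual(map_over(add, a=[1, 2], b=[3, 4]), [4, 6])

    def test_default_arguments(self):
        self.assertEqual(wf(), [11, 22, 33])
```

Saved to a file, the cases run with `python -m unittest <file>`.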
Check all the applicable boxes
- [x] I updated the documentation accordingly.
- [x] All new and existing tests passed.
- [x] All commits are signed-off.
Just noticed that you are super active, will communicate with other maintainers to collaborate with you.
Thank you! I’m looking forward to collaborating with the team and learning more.
sorry this is in the legacy_map_task file... what does the issue look like/is the issue present in the newer map task?
Hi @wild-endeavor,
I tested the issue in both versions (legacy and newer map tasks) and found that the error is similar in both cases, indicating that map tasks do not fully support partial tasks with lists as inputs. However, the error messages differ slightly:
Workflow used:

    @workflow
    def wf(my_list: list[int] = [1, 2, 3]) -> list[int]:
        my_vals = [1, 2]
        partial_task = functools.partial(mult_sum, my_list=my_list)
        return map_task(partial_task)(my_val=my_vals)

• Legacy map_task: When running the workflow with the legacy map task, the error raised is less descriptive, showing something like:
input arrays have different lengths: expecting '3' found '2'
This error tells us the input arrays have mismatched lengths but doesn’t explicitly mention that map tasks require the arrays to be of the same length, which could make debugging harder.
• Newer map_task: When running with the newer map_task, an error is also raised, showing something like:
all map task input lists must be the same length
So, the issue is present in both versions.
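For comparison, here is a plain-Python sketch of what the workflow above is meant to compute, with the bound list passed whole to every mapped call. The body of `mult_sum` is an assumption, since the task definition is not shown in this thread. The two list lengths printed at the end are the values that the remote error reports as mismatched.

```python
import functools


def mult_sum(my_val: int, my_list: list[int]) -> int:
    # Assumed body for illustration; the real task is not shown here.
    return my_val * sum(my_list)


my_list = [1, 2, 3]
my_vals = [1, 2]

# The bound list is a fixed argument; only my_val varies per call.
partial_task = functools.partial(mult_sum, my_list=my_list)
results = [partial_task(my_val=v) for v in my_vals]

print(results)                      # [6, 12]
print(len(my_list), len(my_vals))   # 3 2
```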
Regarding the solution: as I mentioned in the parent issue, the goal is to raise an error at both compile time and runtime if a partial task with lists as inputs is used in a multi-input map task. I am currently working on this, and the team will consider how to better handle partial inputs later.
Let me know if you have any further suggestions or if you’d like to discuss this more.