[BUG] Union types fail for e.g. two different dataclasses
Describe the bug
from typing import Union
from dataclasses import dataclass

from flytekit import task, workflow


@dataclass
class A:
    a: int


@dataclass
class B:
    b: int


@task
def foo(inp: Union[A, B]):
    ...


@workflow
def wf():
    foo(inp=B(b=1))


if __name__ == "__main__":
    wf()
The workflow works when executing locally with python wf.py but fails to register with flyteadmin:
flytekit.exceptions.user.FlyteInvalidInputException: USER:BadInputToAPI: error=None, cause=<_InactiveRpcError of RPC that terminated with:
status = StatusCode.INVALID_ARGUMENT
details = "failed to compile workflow for [resource_type:WORKFLOW project:”…” domain:"development" name:”…f” version:”…”] with err failed to compile workflow with err Collected Errors: 1
Error 0: Code: MismatchingTypes, Node Id: n0, Description: Variable [inp] (type [simple:STRUCT]) doesn't match expected type [union_type:{variants:{simple:STRUCT metadata:{fields:{key:"additionalProperties" value:{bool_value:false}} fields:{key:"properties" value:{struct_value:{fields:{key:"a" value:{struct_value:{fields:{key:"type" value:{string_value:"integer"}}}}}}}} fields:{key:"required" value:{list_value:{values:{string_value:"a"}}}} fields:{key:"title" value:{string_value:"A"}} fields:{key:"type" value:{string_value:"object"}}} structure:{tag:"Object-Dataclass-Transformer"}} variants:{simple:STRUCT metadata:{fields:{key:"additionalProperties" value:{bool_value:false}} fields:{key:"properties" value:{struct_value:{fields:{key:"b" value:{struct_value:{fields:{key:"type" value:{string_value:"integer"}}}}}}}} fields:{key:"required" value:{list_value:{values:{string_value:"b"}}}} fields:{key:"title" value:{string_value:"B"}} fields:{key:"type" value:{string_value:"object"}}} structure:{tag:"Object-Dataclass-Transformer"}}}].
Expected behavior
As a Python developer using Flyte, I would expect this workflow to work.
Additional context to reproduce
The root cause for this error is the following:
When validating the workflow in the backend, here, the so-called unionTypeChecker checks whether one of the union variants (here A or B) can unambiguously be chosen:
func (t unionTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool {
	...
	// Matches iff we can unambiguously select a variant
	foundOne := false
	for _, x := range unionType.GetVariants() {
		if AreTypesCastable(upstreamType, x) {
			if foundOne {
				return false
			}
			foundOne = true
		}
	}
	return foundOne
}
For our example above, AreTypesCastable(upstreamType, x) yields true for both union variants A and B which causes the unionTypeChecker to fail.
AreTypesCastable(upstreamType, x) returns true for both A and B for the following reason:
Here, the so-called trivialChecker, which is called for both union variants A and B, checks whether the passed input B matches the respective variant:
func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool {
	...
	// There are no tags that could solve the ambiguity in our example
	if GetTagForType(upstreamType) != "" && GetTagForType(t.literalType) != GetTagForType(upstreamType) {
		return false
	}
	// Ignore metadata when comparing types.
	upstreamTypeCopy := *upstreamType
	downstreamTypeCopy := *t.literalType
	upstreamTypeCopy.Structure = &flyte.TypeStructure{}
	downstreamTypeCopy.Structure = &flyte.TypeStructure{}
	upstreamTypeCopy.Metadata = &structpb.Struct{}
	downstreamTypeCopy.Metadata = &structpb.Struct{}
	upstreamTypeCopy.Annotation = &flyte.TypeAnnotation{}
	downstreamTypeCopy.Annotation = &flyte.TypeAnnotation{}
	return upstreamTypeCopy.String() == downstreamTypeCopy.String()
}
Since there are no tags that resolve the ambiguity and all metadata is ignored, the final string comparison for both union variants A and B is always "simple:STRUCT metadata:{} annotation:{} structure:{}" == "simple:STRUCT metadata:{} annotation:{} structure:{}". This means we cannot determine whether union variant A or B is the correct match for the passed value B.
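To make the ambiguity concrete, here is a small Python model of the two checkers. It is only a sketch: LiteralType, trivial_casts_from and union_casts_from are toy stand-ins invented for this illustration, not the real flyteidl types or the Go implementation.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class LiteralType:
    """Toy stand-in for a Flyte literal type (assumption, not the real message)."""
    simple: str      # e.g. "STRUCT"
    title: str = ""  # lives in the metadata, which the trivial check ignores
    tag: str = ""    # structure tag, e.g. "Object-Dataclass-Transformer"


def trivial_casts_from(downstream: LiteralType, upstream: LiteralType) -> bool:
    # A tag only rules out a match when the upstream has one and it differs.
    if upstream.tag != "" and downstream.tag != upstream.tag:
        return False
    # Metadata (here: the title) is ignored, so only the simple type is compared.
    return upstream.simple == downstream.simple


def union_casts_from(variants: List[LiteralType], upstream: LiteralType) -> bool:
    # Matches only if exactly one variant accepts the upstream type.
    found_one = False
    for variant in variants:
        if trivial_casts_from(variant, upstream):
            if found_one:
                return False
            found_one = True
    return found_one


variant_a = LiteralType("STRUCT", title="A", tag="Object-Dataclass-Transformer")
variant_b = LiteralType("STRUCT", title="B", tag="Object-Dataclass-Transformer")
upstream_b = LiteralType("STRUCT", title="B")  # the value passed to foo, untagged

print(union_casts_from([variant_a, variant_b], upstream_b))  # False: both variants match
print(union_casts_from([variant_b], upstream_b))             # True: only one candidate
```

In this model the title is the only thing that distinguishes A from B, and it is exactly the part that the comparison throws away.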
Screenshots
No response
Are you sure this issue hasn't been raised already?
- [X] Yes
Have you read the Code of Conduct?
- [X] Yes
The issue arises because Flyte's backend cannot unambiguously select a variant from the union type when validating the workflow. This is due to the unionTypeChecker failing to differentiate between the union variants A and B as both are treated as simple:STRUCT.
Would you like more detailed steps on how to address this issue or further explanation on the underlying mechanics?
@RunLLM Can I work on this issue?
We're in the process of revisiting the dataclass transformer and will include this in the process.
A good point raised during today's contributor meeting is that simply adding the tags to the struct message is not enough as that would break inheritance.
@fg91 Eduardo, Kevin, Thomas, Yee and I just discussed this 10 minutes ago. We can use the dataclass's path to check whether two dataclasses in the Union type are the same. (Please make it backward compatible.)
{"dataclass_path": f"{python_type.__module__}.{python_type.__qualname__}"}
Thank you so much.
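A minimal sketch of what that could look like, assuming a hypothetical helper named dataclass_path and a plain dict as the place where the value is stored (neither is existing flytekit API):

```python
from dataclasses import dataclass


def dataclass_path(python_type: type) -> str:
    # Fully qualified name: module plus qualified class name.
    return f"{python_type.__module__}.{python_type.__qualname__}"


@dataclass
class A:
    a: int


@dataclass
class B:
    b: int


# Hypothetical metadata entries; only the key name comes from the proposal.
meta_a = {"dataclass_path": dataclass_path(A)}
meta_b = {"dataclass_path": dataclass_path(B)}

print(meta_a != meta_b)  # True: the two paths differ
```

With such an entry attached to each variant, two structurally identical STRUCT types would no longer compare equal.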
This doesn't account for inheritance, correct? Or am I overlooking something?
@fg91 We can use class.mro() and put it to tags.
The hard part is that the tag is currently just a string, not a list, so we might have to use a separator to separate all the tags.
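A rough sketch of the separator idea, assuming a hypothetical SEPARATOR constant and hypothetical helpers mro_tag and tag_matches (none of these exist in flytekit):

```python
from dataclasses import dataclass

SEPARATOR = ","  # hypothetical choice; must not occur in a dotted class path


def type_path(python_type: type) -> str:
    return f"{python_type.__module__}.{python_type.__qualname__}"


def mro_tag(python_type: type) -> str:
    # Join the paths of the class and its bases into one string, skipping object.
    return SEPARATOR.join(type_path(c) for c in python_type.mro() if c is not object)


def tag_matches(upstream_tag: str, variant: type) -> bool:
    # A variant matches if its path appears anywhere in the upstream MRO.
    return type_path(variant) in upstream_tag.split(SEPARATOR)


@dataclass
class A:
    a: int


@dataclass
class B:
    b: int


@dataclass
class C(B):
    c: int


tag = mro_tag(C)
matches = [v.__name__ for v in (A, B) if tag_matches(tag, v)]
print(matches)  # ['B']
```

Because C inherits from B, its tag contains the path of B, so exactly one variant of Union[A, B] is selected.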
from typing import Union
from dataclasses import dataclass

from flytekit import task, workflow


@dataclass
class A:
    a: int


@dataclass
class B:
    b: int


@dataclass
class C(B):
    c: int


@task
def foo(inp: Union[A, B]):
    print(inp)


@workflow
def wf():
    foo(inp=C(b=1, c=1))


if __name__ == "__main__":
    print(C.mro())
    wf()
Output:
[<class '__main__.C'>, <class '__main__.B'>, <class 'object'>]
B(b=1)
So yes, I agree, if we serialized the output of .mro into the literal and sent this information to propeller, it could discern which of the union variants fits.
I'm thinking whether one further complication is that depending on whether a class is defined in the main module or imported, it might be called <class '__main__.B'> or <class 'some.module.path.B'>, right?
This problem exists regardless of whether we want to support inheritance or not, also when we just write the class itself into the tag.
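To illustrate the concern, the sketch below builds two class objects that stand for the same class B loaded once as a script and once via import. The classes are created with type() purely for demonstration, and type_path is a hypothetical helper:

```python
def type_path(python_type: type) -> str:
    return f"{python_type.__module__}.{python_type.__qualname__}"


# Simulate the same class being loaded under two different module names.
B_as_script = type("B", (), {"__module__": "__main__"})
B_imported = type("B", (), {"__module__": "some.module.path"})

print(type_path(B_as_script))  # __main__.B
print(type_path(B_imported))   # some.module.path.B
print(type_path(B_as_script) == type_path(B_imported))  # False
```

A tag derived from the module path would therefore differ between the two ways of loading the class.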
Off the top of my head, I don't have a solution for this unfortunately, do you @Future-Outlier?
In my opinion, it is a must for us to solve the issue that the union checker can't distinguish between any two types that use protobuf struct as their literal. I think it would be fair to exclude inheritance, if it makes our lives easier, as long as flyteadmin returns a comprehensible error when registering.
Thank you for the reply. I don't have a solution for this yet.
Regarding the concern that a class might be called <class '__main__.B'> or <class 'some.module.path.B'> depending on where it is defined: can you explain more, or give me an example of how this would cause a problem? I am still trying to come up with an example.
I will think about the alternatives. BTW, I will be going into military service in the next 2 weeks, and I will try to come back and discuss this with you after that.
Isn't there a JSON schema that's published alongside dataclasses that we can rely on? Can we push compat checking to that level?
That is, if you're checking to see if type X is compatible with type Y:
- if X has a schema and Y does not (compatible)
- if X does not have a schema and Y does (not compatible)
- if both have a schema
  - use JSON schema to see if X's schema is compatible with Y's.
- if Y is a Union, iterate through each variant, doing the same check; if more than one match is found, then error
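The rules above could be sketched roughly as follows. This is an illustration under assumptions, not propeller's compiler: schema_compatible and select_variant are hypothetical names, the compatibility rule (required fields present, shared field types equal, extra fields rejected when additionalProperties is false) is one possible reading, and the case where neither side has a schema is treated as compatible by assumption.

```python
from typing import List, Optional

Schema = Optional[dict]  # None stands for "no schema published"


def schema_compatible(upstream: Schema, downstream: Schema) -> bool:
    if downstream is None:
        # Y has no schema: accept (both missing is an assumption, not stated above).
        return True
    if upstream is None:
        # X has no schema but Y does: not compatible.
        return False
    up_props = upstream.get("properties", {})
    down_props = downstream.get("properties", {})
    # Every field the downstream requires must exist upstream.
    for name in downstream.get("required", []):
        if name not in up_props:
            return False
    for name, spec in up_props.items():
        if name in down_props:
            # Shared fields must agree on their declared type.
            if spec.get("type") != down_props[name].get("type"):
                return False
        elif downstream.get("additionalProperties") is False:
            # Downstream forbids fields it does not know about.
            return False
    return True


def select_variant(upstream: Schema, variants: List[Schema]) -> int:
    # Return the index of the single compatible variant, else raise.
    matches = [i for i, v in enumerate(variants) if schema_compatible(upstream, v)]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one matching variant, found {len(matches)}")
    return matches[0]


# Schemas shaped like the ones in the error message from the issue description.
schema_a = {"type": "object", "title": "A", "additionalProperties": False,
            "properties": {"a": {"type": "integer"}}, "required": ["a"]}
schema_b = {"type": "object", "title": "B", "additionalProperties": False,
            "properties": {"b": {"type": "integer"}}, "required": ["b"]}

print(select_variant(schema_b, [schema_a, schema_b]))  # 1: only variant B fits
```

A full implementation would need to handle nested objects, anyOf and similar constructs, which this sketch leaves out.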
I don't like the idea of propeller's compiler checking python mro using string matches. That sounds incredibly error prone and replete with edge cases (changing file names, changing type name, root import location, probably missing more).
Agree 👍
Goal by @wild-endeavor:
from dataclasses import dataclass
from typing import Optional


# upstream
@dataclass
class A:
    a: int


# downstream
@dataclass
class A:
    a: int
    b: Optional[int]
This case should work.
@Future-Outlier @wild-endeavor Agree with the example, with one small addition: it would be really nice if the solution wasn't tailor-made for the dataclass type transformer but generalizable to every transformer that uses protobuf struct as transport. We have defined internal type transformers that use structs and hence suffer from the same problem, and it would be nice to be able to apply the same fix there 🙇
related to https://github.com/flyteorg/flyte/issues/5318
Since this fix by @wild-endeavor in flytekit, the minimal reproducing example from the issue description doesn't fail anymore. I don't understand why because I confirmed with a debugger that the CastsFrom of the unionTypeChecker still considers both variants a match here.
However, this example still doesn't work:
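from dataclasses import dataclass
from typing import Union

from dataclasses_json import dataclass_json
from flytekit import task, workflow


@dataclass_json
@dataclass
class A:
    a: int


@dataclass_json
@dataclass
class B:
    b: int


@task
def bar() -> A:
    return A(a=1)


@task
def foo(inp: Union[A, B]):
    print(inp)


@workflow
def wf():
    v = bar()
    foo(inp=v)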
My friend @mao3267 is investigating this, and I'll support him in making this happen.