pydantic-avro

Add Support for Pydantic Tagged Union

Open Natesuri opened this issue 1 year ago • 4 comments

Pydantic has a new feature for tagged unions.

https://docs.pydantic.dev/latest/usage/types/#discriminated-unions-aka-tagged-unions

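Example usage of a tagged union:

from typing import Literal, Union

from pydantic import BaseModel, Field

class FooA(BaseModel):
    foo_discriminator: Literal["A"]

class FooB(BaseModel):
    foo_discriminator: Literal["B"]

class Bar(BaseModel):
    # The discriminator names the field that selects FooA or FooB
    foo: Union[FooA, FooB] = Field(..., discriminator="foo_discriminator")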

For the tagged union above, this is the value passed to the get_type method:

{'title': 'Foo', 'discriminator': {'propertyName': 'FooDiscriminator', 'mapping': {'A': '#/definitions/FooA', 'B': '#/definitions/FooB'}}, 'oneOf': [{'$ref': '#/definitions/FooA'}, {'$ref': '#/definitions/FooB'}]}

It would be very helpful if Avro schema generation could at least treat a tagged union the same way as a standard union.
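
To make the request concrete, here is a rough sketch of the mapping, working on the plain dict shown above rather than on pydantic-avro itself. `union_member_names` is a hypothetical helper, not part of the library, and bare record names stand in for full Avro record definitions.

```python
from typing import Any, Dict, List

# Schema fragment copied from the get_type input shown above.
TAGGED_UNION: Dict[str, Any] = {
    "title": "Foo",
    "discriminator": {
        "propertyName": "FooDiscriminator",
        "mapping": {"A": "#/definitions/FooA", "B": "#/definitions/FooB"},
    },
    "oneOf": [{"$ref": "#/definitions/FooA"}, {"$ref": "#/definitions/FooB"}],
}


def union_member_names(value: Dict[str, Any], prefix: str = "#/definitions/") -> List[str]:
    """Hypothetical helper: list the model names referenced by anyOf/oneOf."""
    members = value.get("anyOf") or value.get("oneOf") or []
    return [m["$ref"].replace(prefix, "") for m in members if "$ref" in m]


# An Avro union is written as a JSON array of its member types, so the field
# would carry the member records (here: just their names) as its "type".
avro_field = {"name": "foo", "type": union_member_names(TAGGED_UNION)}
print(avro_field)
```

The discriminator mapping is ignored here on purpose: Avro unions have no tag field of their own, so treating the tagged union as a standard union only needs the `oneOf` members.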

Natesuri, Jun 15 '23 18:06

I recently ran into a use-case for this as well. It'd be a great addition.

It looks like this can be supported as a standard union (I haven't looked into tagged unions) with the following one-line change to the get_type method. Update

u = value.get("anyOf")

to

u = value.get("anyOf", value.get("oneOf"))
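
As a quick illustration of what that expression returns, here it is applied to plain dicts outside the library. `union_members` is a hypothetical wrapper used only for this demonstration.

```python
from typing import Any, Dict, List, Optional


def union_members(value: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
    # Same expression as the proposed change: use "anyOf" if the key exists,
    # otherwise fall back to "oneOf" (or None when neither key is present).
    return value.get("anyOf", value.get("oneOf"))


plain = {"anyOf": [{"type": "string"}, {"type": "integer"}]}
tagged = {"oneOf": [{"$ref": "#/definitions/FooA"}, {"$ref": "#/definitions/FooB"}]}
neither = {"type": "string"}

print(union_members(plain))
print(union_members(tagged))
print(union_members(neither))
```

When neither key is present the result is None, so the rest of get_type would fall through to its other branches as before.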

OffByOnee, Aug 24 '23 14:08

The suggested change above was generating an invalid Avro schema. The following patch adds support for converting a JSON Schema discriminated union to an Avro union:

diff --git a/src/pydantic_avro/base.py b/src/pydantic_avro/base.py
index 6e16a30..60912c9 100644
--- a/src/pydantic_avro/base.py
+++ b/src/pydantic_avro/base.py
@@ -49,6 +49,7 @@ class AvroBase(BaseModel):
             r = value.get("$ref")
             a = value.get("additionalProperties")
             u = value.get("anyOf")
+            du = value.get("oneOf")
             minimum = value.get("minimum")
             maximum = value.get("maximum")
             avro_type_dict: Dict[str, Any] = {}
@@ -62,6 +63,10 @@ class AvroBase(BaseModel):
                 avro_type_dict["type"] = []
                 for union_element in u:
                     avro_type_dict["type"].append(get_type(union_element)["type"])
+            elif du is not None:
+                avro_type_dict["type"] = []
+                for union_element in du:
+                    avro_type_dict["type"].append(get_type(union_element)["type"])
             elif r is not None:
                 class_name = r.replace(f"#/{DEFS_NAME}/", "")
                 if class_name in classes_seen:
@@ -105,6 +110,12 @@ class AvroBase(BaseModel):
                 ):
                     items = tn["type"]["items"]
                     tn = {"type": "array", "items": items}
+                # If items in array are a union
+                if (
+                    isinstance(tn, dict)
+                    and isinstance(tn.get("type", {}), list)
+                ):
+                    tn = {"type": "array", "items": tn["type"]}
                 avro_type_dict["type"] = {"type": "array", "items": tn}
             elif t == "string" and f == "date-time":
                 avro_type_dict["type"] = {
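
For anyone who wants to try the idea without checking out the repo, here is a heavily simplified, standalone sketch of the same approach. It is not the library's get_type: the primitive type mapping, the value of `DEFS_NAME`, and the handling of array items are assumptions made for illustration, and referenced models are represented by name only.

```python
from typing import Any, Dict

DEFS_NAME = "definitions"  # assumed; matches the "#/definitions/..." refs in this thread

# Assumed mapping from JSON Schema primitives to Avro primitives.
PRIMITIVES = {"string": "string", "integer": "long", "number": "double", "boolean": "boolean"}


def get_type(value: Dict[str, Any]) -> Dict[str, Any]:
    """Simplified stand-in covering unions, refs, arrays, and primitives."""
    u = value.get("anyOf")
    du = value.get("oneOf")
    r = value.get("$ref")
    t = value.get("type")
    if u is not None or du is not None:
        # Plain and discriminated unions both become an Avro union (a list of types).
        members = u if u is not None else du
        return {"type": [get_type(m)["type"] for m in members]}
    if r is not None:
        # Referenced models are represented by their name only in this sketch.
        return {"type": r.replace(f"#/{DEFS_NAME}/", "")}
    if t == "array":
        # The item type may itself be a union, i.e. a list of member types.
        items = get_type(value["items"])["type"]
        return {"type": {"type": "array", "items": items}}
    if t in PRIMITIVES:
        return {"type": PRIMITIVES[t]}
    raise ValueError(f"unsupported schema fragment: {value!r}")


foo_union = {"oneOf": [{"$ref": "#/definitions/FooA"}, {"$ref": "#/definitions/FooB"}]}
print(get_type(foo_union))
print(get_type({"type": "array", "items": foo_union}))
```

The sketch merges the anyOf and oneOf branches because they build the same Avro type; the patch keeps them separate, which leaves the existing anyOf path untouched.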

OffByOnee, Aug 24 '23 19:08

@timvancann @ffinfo What do you think? Would this be a welcome change?

OffByOnee, Aug 24 '23 19:08

@OffByOnee I don't see any reason why we can't support this. If you create a PR with this change and some tests I can review it properly :).

timvancann, Aug 25 '23 11:08