dataclasses-json

Field is a list of custom fields

Status: Open — nathan5280 opened this issue 4 years ago · 1 comment

I have a dataclass that contains a list of an enumeration. I have custom field metadata that handles a single instance of the enumeration, but I was stumped when the field was a list of enumerations. I solved it by creating a new marshmallow (MM) field for the list, but I was wondering whether anyone has a better solution.

Near the bottom of the file you will find this, which was the troublesome part:

@dataclass_json
@dataclass
class PlanList:
    """Dataclass holding a list of Frequency values."""

    # >>>>>>>>>>  This is the troublesome field <<<<<<<<<<<
    # A List[Frequency] cannot reuse the single-value frequency_field metadata;
    # it needs list-aware encoder/decoder/mm_field metadata (frequency_list_field).
    allowed_frequencies: List[Frequency] = field(metadata=frequency_list_field)

Entire test setup:

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Union

from dataclasses_json import dataclass_json
from marshmallow import ValidationError, fields


class Frequency(Enum):
    """Enumeration of frequencies; serialized by member name in this file.

    Each value appears to be the number of periods per year
    (annually=1, monthly=12).
    """

    annually = 1
    semi_annually = 2
    quarterly = 4
    monthly = 12


class FrequencyField(fields.Field):
    """Marshmallow field that (de)serializes a Frequency member by its name.

    The static ``encoder``/``decoder`` pair performs the same conversion for
    dataclasses_json's ``to_json``/``from_json`` path.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        # A missing/None value serializes to None, as marshmallow's built-in
        # fields do, instead of failing with AttributeError on ``None.name``.
        if value is None:
            return None
        return value.name

    def _deserialize(self, value, attr, data, **kwargs):
        # Unknown (KeyError) or unhashable (TypeError) names are reported as a
        # field-level ValidationError, so schema.load()/loads() collects them
        # like any other validation failure instead of crashing.
        try:
            return Frequency[value]
        except (KeyError, TypeError) as error:
            raise ValidationError(
                f"Invalid frequency {value!r}; expected one of "
                f"{[f.name for f in Frequency]}."
            ) from error

    @staticmethod
    def encoder(frequency: Frequency):
        """Return the member name of *frequency* (used by ``to_json``)."""
        return frequency.name

    @staticmethod
    def decoder(name: str):
        """Return the Frequency member called *name* (used by ``from_json``)."""
        return Frequency[name]


# Field metadata for a single Frequency value: dataclasses_json reads the
# encoder/decoder pair for to_json/from_json and mm_field for its schema.
frequency_field = dict(
    dataclasses_json=dict(
        encoder=FrequencyField.encoder,
        decoder=FrequencyField.decoder,
        mm_field=FrequencyField(),
    )
)


@dataclass_json
@dataclass
class Plan:
    """Dataclass with a single Frequency field, serialized by member name."""

    # frequency_field metadata supplies the name-based encoder/decoder and mm_field.
    frequency: Frequency = field(metadata=frequency_field)


def test_plan_json():
    """Round-trip a Plan through dataclasses_json's to_json/from_json."""
    obj = Plan(frequency=Frequency.quarterly)
    print()
    print("obj:", obj)
    j = obj.to_json()
    print("json:", j)
    obj_rt = Plan.from_json(j)
    # Show the decoded object so the output actually reflects the round trip.
    print("obj rt:", obj_rt)
    assert obj == obj_rt


def test_plan_mm():
    """Round-trip a Plan through the generated marshmallow schema."""
    obj = Plan(frequency=Frequency.quarterly)
    print()
    print("obj:", obj)
    j = Plan.schema().dumps(obj)
    print("json:", j)
    obj_rt = Plan.schema().loads(j)
    # Show the decoded object so the output actually reflects the round trip.
    print("obj rt:", obj_rt)
    assert obj == obj_rt


class FrequencyListField(fields.Field):
    """Marshmallow field that (de)serializes a list of Frequency members by name.

    The static ``encoder``/``decoder`` pair is the dataclasses_json
    counterpart used by ``to_json``/``from_json``.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        # A missing/None list serializes to None, as marshmallow's built-in
        # fields do, rather than failing with TypeError when iterated.
        if value is None:
            return None
        return [f.name for f in value]

    def _deserialize(self, value, attr, data, **kwargs):
        # A bare string is iterable, so without this check it would be looked
        # up character by character; reject it as a validation error instead.
        if isinstance(value, str):
            raise ValidationError(
                f"Expected a list of frequency names, got {value!r}."
            )
        # Unknown/unhashable names (or a non-iterable value) become a
        # field-level ValidationError rather than a bare KeyError/TypeError.
        try:
            return [Frequency[name] for name in value]
        except (KeyError, TypeError) as error:
            raise ValidationError(
                f"Invalid frequency list {value!r}; expected names from "
                f"{[f.name for f in Frequency]}."
            ) from error

    @staticmethod
    def encoder(frequencies: List[Frequency]):
        """Return the member names of *frequencies* (used by ``to_json``)."""
        return [f.name for f in frequencies]

    @staticmethod
    def decoder(value: List[Union[str, Frequency]]):
        """Return *value* as a list of Frequency members (used by ``from_json``).

        When the marshmallow schema is used, ``_deserialize`` has already
        produced Frequency members before this runs. Those are kept as-is, and
        only remaining names are looked up, element by element rather than by
        inspecting just the first entry.
        """
        return [
            item if isinstance(item, Frequency) else Frequency[item]
            for item in value
        ]


# Field metadata for a List[Frequency] value: dataclasses_json reads the
# encoder/decoder pair for to_json/from_json and mm_field for its schema.
frequency_list_field = dict(
    dataclasses_json=dict(
        encoder=FrequencyListField.encoder,
        decoder=FrequencyListField.decoder,
        mm_field=FrequencyListField(),
    )
)


@dataclass_json
@dataclass
class PlanList:
    """Dataclass holding a list of Frequency values."""

    # >>>>>>>>>>  This is the troublesome field <<<<<<<<<<<
    # A List[Frequency] cannot reuse the single-value frequency_field metadata;
    # it uses the list-aware frequency_list_field encoder/decoder/mm_field.
    allowed_frequencies: List[Frequency] = field(metadata=frequency_list_field)


def test_plan_list_json():
    """Round-trip a PlanList through dataclasses_json's to_json/from_json."""
    obj = PlanList(allowed_frequencies=[Frequency.annually, Frequency.monthly])
    print()
    print("obj:", obj)
    j = obj.to_json()
    print("json:", j)
    obj_rt = PlanList.from_json(j)
    # Show the decoded object so the output actually reflects the round trip.
    print("obj rt:", obj_rt)
    assert obj == obj_rt


def test_plan_list_mm():
    """Round-trip a PlanList through the generated marshmallow schema."""
    obj = PlanList(allowed_frequencies=[Frequency.annually, Frequency.monthly])
    print()
    print("obj:", obj)
    j = PlanList.schema().dumps(obj)
    print("json:", j)
    obj_rt = PlanList.schema().loads(j)
    # Show the decoded object so the output actually reflects the round trip.
    print("obj rt:", obj_rt)
    assert obj == obj_rt

— nathan5280, Aug 7, 2019 21:08

I tend to think your solution is more explicit and therefore at least better for code cleanliness. My use case was to take in a dataclass and produce another dataclass that turned certain fields into lists of that field for the derived class. I am still debugging my solution, but your solution has helped me think about this problem, so I figured sharing what I have so far could be helpful :). I suspect that something between the two solutions — a ListFieldsClassFactory — would be the best option, but I'm unsure whether I'll have time to go down that route.

def dataclass_filter_factory(cls):  # noqa: max-complexity: 13
    '''
    From a dataclass specification, generate filtering criteria defined by a
    new dataclass, where a `null`/`None` setting means "no filtering" for that
    criterion.

    This currently supports types [int, complex, float, enum] and skips others
    (fields whose type `retype_field` maps to None are skipped; the type
    handling itself lives in helpers not shown in this excerpt).

    Args:
        cls: dataclass to derive the filter from

    Example:

        class Pet(Enum):
            CAT = "cat"
            DOG = "dog"

        @dataclass
        class Example(DataClassJsonMixin):
            x: int
            pet: Pet

        ExampleFilter = dataclass_filter_factory(Example)
        example_filter = ExampleFilter.from_json(
            '{"x_upper_bound": 4, "x_lower_bound": null, "pet_match_list": ["cat"]}'
        )
        example_filter.satisfies(Example(x=2, pet=Pet.DOG))  # => False
        example_filter.satisfies(Example(x=2, pet=Pet.CAT))  # => True

    Returns:
        dataclass used to specify filtering criteria, with added methods
          `.satisfies(entry)` returning a bool of whether the entry passes the filter
          `.filters(*entries)` yielding all entries that pass the filter criteria
          (both defined in code elided from this excerpt)

    '''
    ...  # excerpt: code elided here

    def update_metadata(metadata, field_key, filter_type):
        # Build the metadata for one generated filter field: the source field's
        # metadata, optionally with list-aware dataclasses_json encoder/decoder,
        # plus a filter-spec entry recording the source field name and filter type.
        # defaultdict lets augment[key][subkey] be assigned without creating augment[key] first.
        augment = defaultdict(dict)

        _conditional_key = 'dataclasses_json'
        _encoder_key = 'encoder'
        _decoder_key = 'decoder'
        # Only EXACT and REGION filters on source fields that already carry
        # dataclasses_json metadata get wrapped codecs; their filter values are
        # presumably lists of the source type (e.g. `pet_match_list`) — confirm.
        # NOTE(review): assumes that metadata has both 'encoder' and 'decoder';
        # a field with only one of them raises KeyError below.
        if _conditional_key in metadata and filter_type in [
            FilterType.EXACT,
            FilterType.REGION,
        ]:
            base_encode_to_str = metadata[_conditional_key][_encoder_key]
            base_decode_from_str = metadata[_conditional_key][_decoder_key]

            # NOTE(review): this returns a JSON *string* (json.dumps), while
            # list_decode_from_str iterates its input element by element. If the
            # decoder is handed that same string back, it iterates characters,
            # not entries — verify the round trip; returning the list itself
            # (no json.dumps) looks like the symmetric choice.
            list_encode_to_str = lambda x: json.dumps(
                [base_encode_to_str(entry) for entry in x]
            )

            list_decode_from_str = lambda x: [
                base_decode_from_str(entry) for entry in x
            ]

            # None means "no filtering", so it is passed through unchanged.
            def null_encoder_to_str(x):
                return x if x is None else list_encode_to_str(x)

            def null_decoder_from_str(x):
                return x if x is None else list_decode_from_str(x)

            augment[_conditional_key][_encoder_key] = null_encoder_to_str
            augment[_conditional_key][_decoder_key] = null_decoder_from_str

        # Shallow merge: when augment has a 'dataclasses_json' entry it replaces
        # the source field's whole 'dataclasses_json' dict, so any other keys
        # there (e.g. 'mm_field') are dropped — presumably intended, since a
        # single-value mm_field would not fit a list; confirm.
        return {
            **metadata,
            **augment,
            _FILTER_SPEC_KEY: {
                _FILTER_SPEC_FIELD_KEY: field_key,
                _FILTER_SPEC_FILTER_TYPE: filter_type,
            },
        }

        # NOTE(review): everything below is indented inside update_metadata and
        # follows its `return`, so it never runs. It looks like it belongs at the
        # factory's own body level (one indent less) — likely an artifact of the
        # excerpt's `...` elisions. As shown, the factory also never returns
        # ClsFilter.
        ...

        # One filter field per (source field, filter type) pair, defaulting to
        # None ("no filtering"); fields whose retyped type is None are skipped.
        field_specs = []
        for f in dataclasses.fields(cls):
            for filter_type in get_filter_types(f.type):
                fname = rename_field(f.name, filter_type)
                ftype = retype_field(f.type)
                if ftype is None:
                    continue
                fmeta = dataclasses.field(
                    default=None, metadata=update_metadata(f.metadata, f.name, filter_type)
                )
                fspec = (fname, ftype, fmeta)
                field_specs.append(fspec)

        # Frozen dataclass named '<cls name>Filter' with DataClassJsonMixin as base
        # (which provides from_json, as used in the docstring example).
        name = f'{cls.__name__}Filter'
        ClsFilter = dataclasses.make_dataclass(
            name, field_specs, bases=(DataClassJsonMixin,), frozen=True
        )

— probinso, Aug 22, 2022 17:08