dataclasses-json
Field is a list of custom fields
I have a dataclass that contains a list of enum values. I have custom field properties that handle a single instance of the enumeration, but I was stumped when I had a list of the enumerations. I solved it by creating a new marshmallow (MM) field for the list, but I was wondering if anyone had a better solution.
Near the bottom of the file you will find the following, which was the troublesome part:
# Excerpt: the class whose List[Frequency] field needed a dedicated
# list-level field definition (frequency_list_field).
@dataclass_json
@dataclass
class PlanList:
    """Plan allowing several ``Frequency`` values, serialized by member name."""

    # >>>>>>>>>> This is the troublesome field <<<<<<<<<<<
    # A list of enums: the single-value frequency_field metadata cannot be
    # reused here, so frequency_list_field supplies list-level encoder/decoder/mm_field.
    allowed_frequencies: List[Frequency] = field(metadata=frequency_list_field)
Entire test setup:
from dataclasses import dataclass, field
from enum import Enum
from typing import List

from dataclasses_json import dataclass_json
from marshmallow import ValidationError
from marshmallow import fields
class Frequency(Enum):
    """Recurrence frequency of a plan.

    Member values look like occurrences per year (annually=1 ... monthly=12).
    The serializers in this file store members by *name*, not by value.
    """

    annually = 1
    semi_annually = 2
    quarterly = 4
    monthly = 12
class FrequencyField(fields.Field):
    """Marshmallow field that (de)serializes a single ``Frequency`` by member name.

    The static ``encoder``/``decoder`` pair provides the same conversion for
    dataclasses-json's ``to_json``/``from_json`` path (see ``frequency_field``).
    """

    def _serialize(self, value, attr, obj, **kwargs):
        # Emit null for a missing value instead of failing on ``None.name``.
        if value is None:
            return None
        return value.name

    def _deserialize(self, value, attr, data, **kwargs):
        try:
            return Frequency[value]
        except (KeyError, TypeError) as err:
            # marshmallow reports bad input via ValidationError; a bare
            # KeyError/TypeError would escape schema.load()/loads() error handling.
            raise ValidationError(f"Unknown frequency: {value!r}") from err

    @staticmethod
    def encoder(frequency: Frequency):
        """Return the member name used as the JSON representation."""
        return frequency.name

    @staticmethod
    def decoder(name):
        """Map a member name back to ``Frequency``.

        Values that are already ``Frequency`` members are returned unchanged,
        so decoding is idempotent (mirrors ``FrequencyListField.decoder``).
        """
        if isinstance(name, Frequency):
            return name
        return Frequency[name]
# dataclasses-json field metadata for a single Frequency: encoder/decoder are
# used on the to_json/from_json path, mm_field on the marshmallow schema() path.
frequency_field = dict(
    dataclasses_json=dict(
        encoder=FrequencyField.encoder,
        decoder=FrequencyField.decoder,
        mm_field=FrequencyField(),
    )
)
@dataclass_json
@dataclass
class Plan:
    """Plan with a single ``Frequency``, serialized by member name."""

    # frequency_field supplies the name-based encoder/decoder and marshmallow field.
    frequency: Frequency = field(metadata=frequency_field)
def test_plan_json():
    """Round-trip a Plan through ``to_json``/``from_json``."""
    obj = Plan(frequency=Frequency.quarterly)
    print()
    print("obj:", obj)
    j = obj.to_json()
    print("json:", j)
    obj_rt = Plan.from_json(j)
    # Print the round-tripped object (previously printed ``obj`` again).
    print("obj rt:", obj_rt)
    assert obj == obj_rt
def test_plan_mm():
    """Round-trip a Plan through the marshmallow schema (``dumps``/``loads``)."""
    obj = Plan(frequency=Frequency.quarterly)
    print()
    print("obj:", obj)
    j = Plan.schema().dumps(obj)
    print("json:", j)
    obj_rt = Plan.schema().loads(j)
    # Print the round-tripped object (previously printed ``obj`` again).
    print("obj rt:", obj_rt)
    assert obj == obj_rt
class FrequencyListField(fields.Field):
    """Marshmallow field that (de)serializes a list of ``Frequency`` by member name.

    The static ``encoder``/``decoder`` pair provides the same conversion for
    dataclasses-json's ``to_json``/``from_json`` path (see ``frequency_list_field``).
    """

    def _serialize(self, value, attr, obj, **kwargs):
        # Emit null for a missing list instead of failing on iteration of None.
        if value is None:
            return None
        return [f.name for f in value]

    def _deserialize(self, value, attr, data, **kwargs):
        try:
            return [Frequency[name] for name in value]
        except (KeyError, TypeError) as err:
            # Report unknown names / non-list input the marshmallow way.
            raise ValidationError(f"Invalid frequency list: {value!r}") from err

    @staticmethod
    def encoder(frequencies: List[Frequency]):
        """Return the list of member names used as the JSON representation."""
        return [f.name for f in frequencies]

    @staticmethod
    def decoder(value: list):
        """Map a list of member names back to ``Frequency`` members.

        Deal with the case where the MM deserialize has already decoded the
        Enum: items that are already ``Frequency`` members pass through. The
        check is per item, so mixed and empty lists are handled too (the old
        version only inspected the first element).
        """
        return [v if isinstance(v, Frequency) else Frequency[v] for v in value]
# dataclasses-json field metadata for a List[Frequency]: encoder/decoder are
# used on the to_json/from_json path, mm_field on the marshmallow schema() path.
frequency_list_field = dict(
    dataclasses_json=dict(
        encoder=FrequencyListField.encoder,
        decoder=FrequencyListField.decoder,
        mm_field=FrequencyListField(),
    )
)
@dataclass_json
@dataclass
class PlanList:
    """Plan allowing several ``Frequency`` values, serialized by member name."""

    # >>>>>>>>>> This is the troublesome field <<<<<<<<<<<
    # A list of enums: frequency_list_field supplies list-level
    # encoder/decoder/mm_field because the single-value metadata cannot be reused.
    allowed_frequencies: List[Frequency] = field(metadata=frequency_list_field)
def test_plan_list_json():
    """Round-trip a PlanList through ``to_json``/``from_json``."""
    obj = PlanList(allowed_frequencies=[Frequency.annually, Frequency.monthly])
    print()
    print("obj:", obj)
    j = obj.to_json()
    print("json:", j)
    obj_rt = PlanList.from_json(j)
    # Print the round-tripped object (previously printed ``obj`` again).
    print("obj rt:", obj_rt)
    assert obj == obj_rt
def test_plan_list_mm():
    """Round-trip a PlanList through the marshmallow schema (``dumps``/``loads``)."""
    obj = PlanList(allowed_frequencies=[Frequency.annually, Frequency.monthly])
    print()
    print("obj:", obj)
    j = PlanList.schema().dumps(obj)
    print("json:", j)
    obj_rt = PlanList.schema().loads(j)
    # Print the round-tripped object (previously printed ``obj`` again).
    print("obj rt:", obj_rt)
    assert obj == obj_rt
I tend to think your solution is more explicit and therefore at least better for code cleanliness. My use case was to take in a `dataclass` and produce another `dataclass` that turned certain fields into lists of that field for the derived class. I am still debugging my solution idea, but your solution has helped me think about this problem, so I figured sharing what I have so far could be helpful :). I suspect that something between the two solutions, such as a `ListFieldsClassFactory`, would be the best option, but I'm unsure if I'll have time to go down that route.
def dataclass_filter_factory(cls):  # noqa: max-complexity: 13
    '''
    Generate a filter dataclass from a dataclass specification.

    The generated dataclass describes filtering criteria; every filter field
    defaults to ``None``, and a ``null``/``None`` setting means "no filtering
    on this criterion".
    This currently supports types [int, complex, float, enum] and skips others.

    Args:
        cls: dataclass to derive the filter from

    Example:
        class Pet(Enum):
            CAT = "cat"
            DOG = "dog"

        @dataclass
        class Example(DataClassJsonMixin):
            x: int
            pet: Pet

        ExampleFilter = dataclass_filter_factory(Example)
        example_filter = ExampleFilter.from_json(
            '{"x_upper_bound": 4, "x_lower_bound": null, "pet_match_list": ["cat"]}'
        )
        example_filter.satisfies(Example(x=2, pet=Pet.DOG))  # => False
        example_filter.satisfies(Example(x=2, pet=Pet.CAT))  # => True

    Returns:
        dataclass used to specify filtering criteria, with added methods
        `.satisfies(entry)` returning a bool of whether the entry passes the filter
        `.filters(*entries)` yielding all entries that pass the filter criteria
    '''
    ...
def update_metadata(metadata, field_key, filter_type):
    """Build the metadata for one generated filter field.

    Args:
        metadata: metadata mapping of the source dataclass field.
        field_key: name of the source field this filter applies to.
        filter_type: kind of filter being generated (compared against
            ``FilterType.EXACT`` / ``FilterType.REGION``).

    Returns:
        A new dict containing ``metadata`` plus a ``_FILTER_SPEC_KEY`` entry
        recording ``field_key`` and ``filter_type``. For EXACT/REGION filters on
        a field that carries a dataclasses-json codec, the per-item
        encoder/decoder are wrapped so the filter field holds an optional
        *list* of values.

    Note:
        When wrapping happens, the new ``'dataclasses_json'`` entry replaces
        the original one wholesale, so other options stored there (e.g.
        ``mm_field``) are not carried over.
    """
    augment = defaultdict(dict)
    _conditional_key = 'dataclasses_json'
    _encoder_key = 'encoder'
    _decoder_key = 'decoder'
    codec = metadata.get(_conditional_key)
    if codec is not None and filter_type in (FilterType.EXACT, FilterType.REGION):
        # A codec may define only one direction; wrap whatever is present
        # instead of raising KeyError on the missing one.
        base_encode_to_str = codec.get(_encoder_key)
        base_decode_from_str = codec.get(_decoder_key)

        if base_encode_to_str is not None:
            def null_encoder_to_str(x):
                # None means "no filter" and passes through untouched.
                if x is None:
                    return None
                return json.dumps([base_encode_to_str(entry) for entry in x])

            augment[_conditional_key][_encoder_key] = null_encoder_to_str

        if base_decode_from_str is not None:
            def null_decoder_from_str(x):
                if x is None:
                    return None
                if isinstance(x, str):
                    # The encoder above emits a JSON-array *string*; accept it so
                    # encoded values decode back instead of being iterated
                    # character by character. Plain lists still work as before.
                    x = json.loads(x)
                return [base_decode_from_str(entry) for entry in x]

            augment[_conditional_key][_decoder_key] = null_decoder_from_str
    return {
        **metadata,
        **augment,
        _FILTER_SPEC_KEY: {
            _FILTER_SPEC_FIELD_KEY: field_key,
            _FILTER_SPEC_FILTER_TYPE: filter_type,
        },
    }
...
# Build (name, type, field) specs for make_dataclass: one filter field per
# (source field, filter type) pair reported by get_filter_types().
field_specs = []
for f in dataclasses.fields(cls):
    for filter_type in get_filter_types(f.type):
        fname = rename_field(f.name, filter_type)
        # NOTE(review): retype_field only receives f.type, so it could likely be
        # hoisted out of this inner loop; left as-is in case the helper has
        # side effects or call-order dependencies — confirm.
        ftype = retype_field(f.type)
        if ftype is None:
            # No filter field is generated when retype_field returns None.
            continue
        # Every filter field defaults to None, i.e. "no filtering criterion".
        fmeta = dataclasses.field(
            default=None, metadata=update_metadata(f.metadata, f.name, filter_type)
        )
        fspec = (fname, ftype, fmeta)
        field_specs.append(fspec)
name = f'{cls.__name__}Filter'
# frozen=True: instances of the generated filter class are immutable.
# NOTE(review): presumably the enclosing function goes on to attach
# satisfies()/filters() and return ClsFilter — not visible in this excerpt.
ClsFilter = dataclasses.make_dataclass(
    name, field_specs, bases=(DataClassJsonMixin,), frozen=True
)