aioinject icon indicating copy to clipboard operation
aioinject copied to clipboard

Dataclass with default factory not supported.

Open nrbnlulu opened this issue 2 years ago • 5 comments

example

import dataclasses

import aioinject


@dataclasses.dataclass(slots=True)
class FooBase[T]:
	bar: dict[T] = dataclasses.field(default_factory=dict)

	def baz(self):
		raise NotImplementedError()


class FooImpl(FooBase[int]):
    """Concrete ``FooBase[int]`` whose ``baz`` does nothing."""

    def baz(self):
        """Override the base placeholder with a no-op."""
        return None


# Reproduction: register FooImpl as a singleton provider, then resolve it
# inside a synchronous container context.
container = aioinject.Container()
container.register(aioinject.Singleton(FooImpl))

# NOTE(review): per the issue title, this resolve is presumably where the
# default-factory field `bar` is not handled -- confirm against aioinject.
with container.sync_context() as ctx:
	foo_impl = ctx.resolve(FooImpl)

To get around this, you could do something like:

import dataclasses
from typing import TYPE_CHECKING

import aioinject


@dataclasses.dataclass(slots=True)
class FooBase[T]:
	if TYPE_CHECKING:
		bar: dict[T] = dataclasses.field(default_factory=dict)
	else:
		bar: dict = dataclasses.field(default_factory=dict)

	def baz(self):
		raise NotImplementedError()


class FooImpl(FooBase[int]):
    """Concrete ``FooBase[int]`` whose ``baz`` does nothing."""

    def baz(self):
        """Override the base placeholder with a no-op."""
        return None


container = aioinject.Container()
# Register a transient provider for `dict` alongside the FooImpl singleton,
# both in a single register() call.
container.register(
    aioinject.Transient(dict),
    aioinject.Singleton(FooImpl),
)

# Resolve FooImpl inside a synchronous container context.
with container.sync_context() as ctx:
    foo_impl = ctx.resolve(FooImpl)

I guess this can be solved by ignoring fields that have default factories.

nrbnlulu avatar Apr 11 '24 13:04 nrbnlulu

@nrbnlulu I think it's possible to just ignore parameters that have default arguments, but honestly I'm not sure whether they should be ignored or whether the container should try to resolve them 🤔

import inspect

# The dataclass-generated __init__ still exposes `bar` as a parameter; its
# default is a placeholder object that prints as "<factory>".
sig = inspect.signature(FooBase.__init__)
print(sig.parameters["bar"].default)  # <factory>

notypecheck avatar Apr 11 '24 21:04 notypecheck

I'll try to submit a PR soon

nrbnlulu avatar Apr 12 '24 15:04 nrbnlulu

@nrbnlulu Should be possible with #18

import dataclasses
from dataclasses import is_dataclass
from typing import Any

import aioinject
from aioinject import Provider
from aioinject.extensions.builtin import BuiltinDependencyExtractor
from aioinject.providers import Dependency


@dataclasses.dataclass(slots=True)
class FooBase[T]:
    bar: dict[T] = dataclasses.field(default_factory=dict)

    def baz(self):
        raise NotImplementedError()


class FooImpl(FooBase[int]):
    """Concrete ``FooBase[int]`` whose ``baz`` does nothing."""

    def baz(self):
        """Override the base placeholder with a no-op."""
        return None


class DataclassExtractor(BuiltinDependencyExtractor):
    """Extractor for dataclass providers that skips default-factory fields.

    Dependencies are extracted by ``BuiltinDependencyExtractor`` and then
    filtered: any dependency whose name matches a dataclass field declared
    with ``default_factory`` is dropped from the result.
    """

    def extract_supports(self, provider: Provider[Any]) -> bool:
        """Return a truthy value only for dataclass providers the base extractor supports."""
        return super().extract_supports(provider) and is_dataclass(
            provider.impl
        )

    def extract_dependencies(
        self,
        provider: Provider[Any],
        context: dict[str, Any],
    ) -> tuple[Dependency[object], ...]:
        """Return the base extractor's dependencies minus default-factory fields.

        Args:
            provider: Provider whose ``impl`` is a dataclass.
            context: Passed through unchanged to the base extractor.

        Returns:
            The dependencies reported by the base extractor, in the same
            order, excluding those named after a field that has a
            ``default_factory``.
        """
        # ``dataclasses.MISSING`` is a sentinel object, so it is compared by
        # identity. Only ``default_factory`` is checked here; fields with a
        # plain ``default`` are left for the base extractor to report.
        factory_field_names = {
            f.name
            for f in dataclasses.fields(provider.impl)
            if f.default_factory is not dataclasses.MISSING
        }
        extracted = super().extract_dependencies(
            provider=provider, context=context
        )
        return tuple(
            dep for dep in extracted if dep.name not in factory_field_names
        )


# Pass the custom extractor to the container through `extensions`.
container = aioinject.Container(extensions=[DataclassExtractor()])
container.register(aioinject.Singleton(FooImpl))
# Only FooImpl is registered in this snippet (no provider for `dict`).
# NOTE(review): presumably resolution succeeds because the extractor drops
# the default-factory field `bar` -- confirm against aioinject.
with container.sync_context() as ctx:
    foo_impl = ctx.resolve(FooImpl)

notypecheck avatar Nov 21 '24 08:11 notypecheck

Ah this is very cool

nrbnlulu avatar Nov 21 '24 08:11 nrbnlulu