di
Add bind_by_generic
Hey, I was wondering if we could add a bind_by_generic function.
For example, it would be nice to be able to bind something like this:
import typing

KT = typing.TypeVar("KT")
VT = typing.TypeVar("VT")


class Record(typing.Generic[KT, VT]):
    def __init__(self, key: KT, value: VT):
        self.key = key
        self.value = value


def func_hinted(record: Record[str, int]) -> Record[str, int]:
    return record


def func_base(record: Record) -> Record:
    return record
I came up with this function, and it works for my case:
import inspect
from typing import Any, get_origin

from di._container import BindHook
from di._utils.inspect import get_type
from di.api.dependencies import DependentBase


def bind_by_generic(
    provider: DependentBase[Any],
    dependency: type,
) -> BindHook:
    """Hook to substitute the matched dependency based on its generic."""

    def hook(
        param: inspect.Parameter | None, dependent: DependentBase[Any]
    ) -> DependentBase[Any] | None:
        # Direct match: the dependent's callable is the class itself.
        if dependent.call == dependency:
            return provider
        # Without a parameter there is no annotation to inspect.
        if param is None:
            return None
        type_annotation_option = get_type(param)
        if type_annotation_option is None:
            return None
        type_annotation = type_annotation_option.value
        # Parametrized match: Record[str, int] has Record as its origin.
        if get_origin(type_annotation) is dependency:
            return provider
        return None

    return hook
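To illustrate the check the hook relies on, here is a minimal standard-library sketch that does not use di at all. The helper name matches_generic is hypothetical and only mirrors the get_origin comparison above; it shows that the parametrized annotation Record[str, int] is not the class Record itself, but still reports Record as its origin.

```python
import inspect
import typing
from typing import get_args, get_origin

KT = typing.TypeVar("KT")
VT = typing.TypeVar("VT")


class Record(typing.Generic[KT, VT]):
    def __init__(self, key: KT, value: VT):
        self.key = key
        self.value = value


def func_hinted(record: Record[str, int]) -> Record[str, int]:
    return record


def func_base(record: Record) -> Record:
    return record


def matches_generic(annotation: typing.Any, dependency: type) -> bool:
    """Hypothetical helper: True for the class itself or any parametrization of it."""
    return annotation is dependency or get_origin(annotation) is dependency


# Read the raw annotations off the two example functions.
hinted = inspect.signature(func_hinted).parameters["record"].annotation
base = inspect.signature(func_base).parameters["record"].annotation

print(hinted is Record)                 # the parametrized alias is a different object
print(get_origin(hinted) is Record)     # but its origin is the generic class
print(get_origin(base))                 # a bare class has no origin
print(matches_generic(hinted, Record), matches_generic(base, Record))
```

The bare Record annotation has no origin, which is why the hook above also needs the separate dependent.call == dependency comparison to cover func_base.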
Does this make sense as a separate hook, or should it be part of bind_by_type? I'm happy to send a PR if that would be welcome.