
Add bind_by_generic

Open woile opened this issue 1 year ago • 0 comments

Hey, I was wondering whether we could add a bind_by_generic function.

For example, it would be nice to be able to bind something like this:

import typing

KT = typing.TypeVar("KT")
VT = typing.TypeVar("VT")

class Record(typing.Generic[KT, VT]):
    def __init__(self, key: KT, value: VT):
        self.key = key
        self.value = value

def func_hinted(record: Record[str, int]) -> Record[str, int]:
    return record

def func_base(record: Record) -> Record:
    return record

I came up with this function, and it works for my case:

import inspect
from typing import Any, get_origin

from di._container import BindHook
from di._utils.inspect import get_type
from di.api.dependencies import DependentBase


def bind_by_generic(
    provider: DependentBase[Any],
    dependency: type,
) -> BindHook:
    """Hook to substitute the matched dependency based on its generic."""

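    def hook(
        param: inspect.Parameter | None, dependent: DependentBase[Any]
    ) -> DependentBase[Any] | None:
        # Direct match: the dependent's callable is the bare generic class.
        if dependent.call == dependency:
            return provider
        # Without a parameter there is no annotation to inspect.
        if param is None:
            return None

        type_annotation_option = get_type(param)
        if type_annotation_option is None:
            return None
        type_annotation = type_annotation_option.value
        # get_origin(Record[str, int]) is Record, so any parameterization matches.
        if get_origin(type_annotation) is dependency:
            return provider
        return None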

    return hook
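
To illustrate the matching rule on its own, here is a minimal stdlib-only sketch that does not touch di at all. It assumes the Record class and the two functions from the example above; matches_generic is a hypothetical helper name used only for illustration, not something di provides.

```python
import inspect
import typing

KT = typing.TypeVar("KT")
VT = typing.TypeVar("VT")


class Record(typing.Generic[KT, VT]):
    def __init__(self, key: KT, value: VT):
        self.key = key
        self.value = value


def func_hinted(record: Record[str, int]) -> Record[str, int]:
    return record


def func_base(record: Record) -> Record:
    return record


def matches_generic(annotation: typing.Any, dependency: type) -> bool:
    """Hypothetical helper: match the bare class or any parameterization of it."""
    return annotation is dependency or typing.get_origin(annotation) is dependency


# Pull the raw annotations off the two example functions.
hinted = inspect.signature(func_hinted).parameters["record"].annotation
base = inspect.signature(func_base).parameters["record"].annotation

# A parameterized alias reports the generic class as its origin;
# the bare class has no origin, so it needs the identity check instead.
print(typing.get_origin(hinted) is Record)  # True
print(typing.get_origin(base))              # None
print(matches_generic(hinted, Record), matches_generic(base, Record))  # True True
```

The bare-class case is why the hook checks dependent.call first: get_origin alone would miss a parameter annotated as plain Record.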

Does this make sense as a separate hook, or should it be part of bind_by_type? Thoughts? I'm happy to send a PR if that would be welcome.

woile, Oct 23 '24 13:10