stories icon indicating copy to clipboard operation
stories copied to clipboard

Stories.

Open proofit404 opened this issue 10 months ago • 2 comments

from __future__ import annotations

from asyncio import run
from collections.abc import Awaitable
from collections.abc import Callable
from dataclasses import dataclass

from pydantic import BaseModel
from stories import Story

from aioapp.entities import Customer
from aioapp.entities import Order
from aioapp.repositories import create_payment
from aioapp.repositories import CustomerId
from aioapp.repositories import load_customer
from aioapp.repositories import load_order
from aioapp.repositories import OrderId


@dataclass
class Purchase:
    """Purchase use case expressed as a four-step ``Story``.

    Repository callables are injected as dataclass fields; each step reads its
    inputs from, and writes its outputs to, a shared ``Purchase.State``.
    """

    @property
    def purchase(self) -> Story:
        """Assemble the ``purchase`` story from this instance's step methods."""
        return (
            Story("purchase")
            .step(self.find_order)
            .step(self.find_customer)
            .step(self.check_balance)
            .step(self.persist_payment)
        )

    # Step annotations name the nested class explicitly: with postponed
    # annotations a bare ``State`` would be resolved at module level, where
    # the nested class is not visible.
    async def find_order(self, state: Purchase.State) -> None:
        """Load the order for ``state.order_id`` into ``state.order``."""
        state.order = await self.load_order(state.order_id)

    async def find_customer(self, state: Purchase.State) -> None:
        """Load the customer for ``state.customer_id`` into ``state.customer``."""
        state.customer = await self.load_customer(state.customer_id)

    async def check_balance(self, state: Purchase.State) -> None:
        """Stop the story when the customer cannot afford the order.

        Raises:
            ValueError: if the order or customer has not been loaded yet, or
                if ``order.affordable_for(customer)`` is false.
        """
        order, customer = state.order, state.customer
        if order is None or customer is None:
            msg = 'find_order and find_customer must run before check_balance'
            raise ValueError(msg)
        if not order.affordable_for(customer):
            msg = (
                f'Order {state.order_id} is not affordable '
                f'for customer {state.customer_id}'
            )
            raise ValueError(msg)

    async def persist_payment(self, state: Purchase.State) -> None:
        """Create the payment for ``state.order_id`` / ``state.customer_id``."""
        state.payment = await self.create_payment(
            order_id=state.order_id,
            customer_id=state.customer_id,
        )

    load_order: Callable[[OrderId], Awaitable[Order]]
    load_customer: Callable[[CustomerId], Awaitable[Customer]]
    create_payment: Callable[[OrderId, CustomerId], Awaitable[None]]

    class State(BaseModel):
        """Data shared between the purchase steps."""

        # Inputs supplied by the caller.
        order_id: OrderId
        customer_id: CustomerId

        # Outputs filled in by the steps. They default to ``None`` so the
        # state can be built from the inputs alone (as in the
        # ``Purchase.State(order_id=1, customer_id=1)`` call below); required
        # fields would make that construction fail validation.
        order: Order | None = None
        customer: Customer | None = None
        # Result of ``create_payment`` (declared to resolve to ``None``). It
        # must be declared because pydantic models reject assignment to
        # undeclared attributes.
        payment: None = None
        # NOTE(review): if Order/Customer are not pydantic-compatible types,
        # this model also needs ``arbitrary_types_allowed`` — confirm.


purchase = Purchase(
    load_order=load_order,
    load_customer=load_customer,
    create_payment=create_payment,
)


state = Purchase.State(order_id=1, customer_id=1)


run(purchase.purchase.call(state))
# API sketches for proposed control-flow combinators. The names ``one``,
# ``two``, ``three``, ``is_ok``, ``substory``, ``two_then`` and ``two_else``
# are placeholders that are not defined in this snippet, and the semantics of
# ``when``/``unless``/``cond``/``nest`` are not shown here.
# NOTE(review): presumably ``when(is_ok)`` runs the following step only if
# ``is_ok`` holds — confirm.
Story("name").step(one).when(is_ok).step(two).call({})

# NOTE(review): presumably the inverse of ``when`` — confirm.
Story("name").step(one).unless(is_ok).step(two).call({})

# NOTE(review): presumably ``cond`` chooses between the next two steps — confirm.
Story("name").step(one).cond(is_ok).step(two).step(three).call({})

# NOTE(review): presumably ``nest`` embeds another story as a step — confirm.
Story("name").step(one).nest(substory).call({})

# Combines a condition with a nested story followed by a further step.
Story("name").step(one).when(is_ok).nest(two_then).step(two_else).call({})

proofit404 avatar Apr 27 '25 20:04 proofit404

Hi! Here are my suggestions for the architecture of the Stories library, along with a working prototype. I recommend paying attention to the usage examples in the main function.

Highlights:

  • No context object pattern is used — this preserves natural method signatures, keeps the code more readable, and allows the methods to be used outside the Stories framework (see the example in the main function).

  • Mypy-friendly — supports type checking and IDE features such as autocomplete for return values and input state type annotations.

  • The state is based on a simple DI container (`kink.Container`). A dataclass can be passed instead; it is copied into a new container, so the caller's object is never modified.

  • It’s possible to validate story consistency during application initialization — all steps can be checked to ensure that the type annotations allow the story to run.

I would strongly recommend making stories flat, since conditional execution flow adds complexity.

"""
Prerequisites:
  - Python 3.11+ (the code imports `typing.Self`, which was added in 3.11)
  - `kink` library installed (`pip install kink`)
"""

import logging
from asyncio import run
from collections.abc import Awaitable, Callable
from dataclasses import Field, asdict, dataclass
from decimal import Decimal
from typing import (
    Any,
    ClassVar,
    Generic,
    NewType,
    Protocol,
    Self,
    TypeVar,
    cast,
    get_type_hints,
    runtime_checkable,
)

from kink import Container

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
)

CustomerID = NewType('CustomerID', str)
OrderID = NewType('OrderID', str)
StateType = dict[str | type, Any]


@runtime_checkable
class DataclassLike(Protocol):
    """Structural type satisfied by dataclass instances.

    ``@dataclass`` sets ``__dataclass_fields__`` on the class, so a runtime
    ``isinstance`` check against this protocol accepts any dataclass instance
    without importing its concrete type.
    """

    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]



class StoryState(Container):
    """
    A container to hold the state of the story.
    It can be used to pass data between steps in the story.

    Entries are keyed either by parameter name (``str``) or by type.
    NOTE(review): kink's ``Container`` may treat callable values as factories
    on lookup — confirm before storing callables in the state.
    """

    def __init__(self, items: StateType | None = None, /) -> None:
        super().__init__()
        if items is not None:
            for key, value in items.items():
                self[key] = value

    def to_dict(self) -> dict[str | type, Any]:
        """
        Return a shallow snapshot of the stored entries.
        This is useful for logging or debugging.

        A copy is returned so that callers (including deferred log formatting)
        can neither mutate the container's internals nor observe later changes.
        """
        return dict(self._services)

    @classmethod
    def from_dataclass(cls, obj: 'DataclassLike') -> Self:
        """
        Build a state whose keys are *obj*'s field names.

        Field values are taken as-is rather than via ``dataclasses.asdict``.
        ``asdict`` recurses into nested dataclasses (turning an ``Order`` field
        into a plain ``dict``) and deep-copies everything else, so steps
        annotated with such types would receive the wrong kind of object.
        """
        from dataclasses import fields

        items: StateType = {
            field.name: getattr(obj, field.name) for field in fields(obj)
        }
        return cls(items)


@dataclass
class Customer:
    """A buyer together with the funds available to them."""

    id: CustomerID  # identifier used to load the customer
    name: str
    balance: Decimal  # compared against an order's amount when paying


@dataclass
class Order:
    """An order to be paid for: its price and a human-readable description."""

    id: OrderID
    amount: Decimal
    description: str

    def affordable_for(self, customer: 'Customer') -> bool:
        """Return True when *customer*'s balance covers this order's amount."""
        return self.amount <= customer.balance


@dataclass
class Payment:
    """Record of money paid by a customer for an order."""

    order_id: OrderID  # order being paid for
    customer_id: CustomerID  # customer who pays
    amount: Decimal


async def load_order(order_id: OrderID) -> Order:
    """Return a canned order for *order_id* (stand-in for a real data source)."""
    canned_amount = Decimal('100.0')
    return Order(order_id, canned_amount, 'Sample Order')


async def load_customer(customer_id: CustomerID) -> Customer:
    """Return a canned customer for *customer_id* (stand-in for a real data source)."""
    starting_balance = Decimal('150.0')
    return Customer(customer_id, 'John Doe', starting_balance)


async def create_payment(order_id: OrderID, customer_id: CustomerID) -> Payment:
    """Return a canned payment for the given ids (stand-in for a real service).

    The amount is fixed at 100.0, the same as the canned order from
    ``load_order``. NOTE(review): it is not looked up from the order itself.
    """
    fixed_amount = Decimal('100.0')
    return Payment(order_id, customer_id, fixed_amount)


# Type of the value a story run returns.
StoryReturnT = TypeVar('StoryReturnT')
# Type of the initial state a runner accepts: a StoryState or a dataclass.
StoryStateT = TypeVar(
    'StoryStateT',
    bound=StoryState | DataclassLike,
)


class StoryRunner(Generic[StoryReturnT, StoryStateT]):
    """Run a sequence of async steps whose arguments come from a shared state.

    For every step the runner reads its type hints and fills each annotated
    parameter from a ``StoryState``. It looks first by the annotated type,
    then by the parameter name; a name match overrides a type match.

    When a step's return annotation is not ``None``, its result is stored in
    the state under that type, so later steps can request it by type. The
    story returns the last such result.

    Stored as a class attribute, the runner acts as a descriptor. Attribute
    access yields a copy bound to the accessing instance (or to the owner
    class when accessed on the class itself), and that object is passed to
    every step as ``self``.
    """

    def __init__(
        self,
        name: str | None = None,
        steps: tuple[Callable[..., Awaitable[Any]], ...] = (),
    ) -> None:
        self.name = name
        self.steps: list[Callable[..., Awaitable[Any]]] = list(steps)
        # Object passed to steps as ``self``; only set on copies made by __get__.
        self._story_instance: Any = None

    def step(self, func: Callable[..., Awaitable[Any]]) -> Self:
        """Append *func* as the next step and return the runner for chaining."""
        self.steps.append(func)
        return self

    def __set_name__(self, owner: type, name: str) -> None:
        """Default the story name to ``Owner.attribute`` when none was given."""
        if not self.name:
            self.name = f'{owner.__name__}.{name}'

    def __get__(self, instance: Any, owner: type) -> Self:  # noqa: ANN401
        """Return a copy of this runner bound to *instance* (or *owner*).

        Binding a copy, instead of storing the instance on ``self``, keeps the
        runner held by the class free of per-instance state. Concurrent runs
        started from different instances therefore cannot overwrite each
        other's ``self`` between awaited steps.

        The shallow copy shares the ``steps`` list, so ``step()`` called
        through any binding still extends the story.
        """
        from copy import copy

        bound = copy(self)
        # ``is None`` rather than truthiness, so falsy instances are still bound.
        bound._story_instance = owner if instance is None else instance
        return bound

    @staticmethod
    def _as_story_state(initial_state: object) -> StoryState:
        """Normalize the caller's initial state into a ``StoryState``."""
        if isinstance(initial_state, DataclassLike):
            return StoryState.from_dataclass(initial_state)
        if isinstance(initial_state, StoryState):
            return initial_state
        msg = f'Expected StoryState or dataclass instance, got {type(initial_state)}'
        raise TypeError(msg)

    @staticmethod
    def _collect_kwargs(
        param_hints: dict[str, Any],
        state: StoryState,
    ) -> dict[str, Any]:
        """Resolve step parameters from *state*: by type first, then by name."""
        step_kwargs: dict[str, Any] = {}
        for param, param_type in param_hints.items():
            if param_type in state:
                step_kwargs[param] = state[param_type]

            if param in state:
                # A name-keyed value may not match the NewType's underlying
                # type; warn but still pass it through.
                if type(param_type) is NewType and not isinstance(
                    state[param],
                    cast('type | tuple[type, ...]', param_type.__supertype__),
                ):
                    logger.warning(
                        'Parameter "%s" has type %s but is provided as %s',
                        param,
                        param_type,
                        type(state[param]),
                    )
                step_kwargs[param] = state[param]
        return step_kwargs

    async def __call__(self, initial_state: StoryStateT) -> StoryReturnT:
        """Run every step in order and return the story's result.

        Raises:
            TypeError: if *initial_state* is neither a ``StoryState`` nor a
                dataclass instance.
            RuntimeError: if no step with a non-``None`` return annotation ran.
        """
        state = self._as_story_state(initial_state)

        logger.info(
            'Running story: %s with initial state: %s',
            self.name, state.to_dict(),
        )
        story_result: StoryReturnT | None = None

        for step in self.steps:
            type_hints = get_type_hints(step)
            return_type = type_hints.pop('return', None)
            step_kwargs = self._collect_kwargs(type_hints, state)

            logger.info('Executing step: %s with args: %s', step.__qualname__, step_kwargs)
            if self._story_instance is not None:
                step_kwargs['self'] = self._story_instance
            result = await step(**step_kwargs)

            # Steps annotated ``-> None`` (e.g. checks) produce no state entry.
            if not (return_type is None or return_type is type(None)):
                state[return_type] = result
                story_result = result

        if story_result is None:
            msg = f'Story {self.name} did not return a result.'
            raise RuntimeError(msg)

        return story_result


class CreatePaymentProtocol(Protocol):
    """Async callable that creates a ``Payment`` for an order and a customer."""

    async def __call__(
        self,
        order_id: OrderID,
        customer_id: CustomerID,
    ) -> Payment:
        ...


class LoadOrderProtocol(Protocol):
    """Async callable that fetches the ``Order`` with the given id."""

    async def __call__(
        self,
        order_id: OrderID,
    ) -> Order:
        ...


class LoadCustomerProtocol(Protocol):
    """Async callable that fetches the ``Customer`` with the given id."""

    async def __call__(
        self,
        customer_id: CustomerID,
    ) -> Customer:
        ...



@dataclass
class State:
    """Typed initial state for running ``PurchaseStory`` from a dataclass.

    The field names match the step parameters ``order_id`` and ``customer_id``,
    so the runner injects them by name. Both fields use the identifier
    NewTypes that the steps declare, so type checkers see consistent types.
    """

    order_id: OrderID
    customer_id: CustomerID


@dataclass
class PurchaseStory:
    """Purchase use case assembled from injected repository callables.

    The steps are ordinary keyword-only coroutine methods. They can be awaited
    directly or driven by ``StoryRunner``, which supplies each keyword from
    the story state (by parameter name or annotated type) and stores each
    non-``None`` result under its return type.
    """

    load_order: LoadOrderProtocol
    load_customer: LoadCustomerProtocol
    create_payment: CreatePaymentProtocol

    async def find_order(self, *, order_id: OrderID) -> Order:
        """Fetch the order identified by *order_id*."""
        found = await self.load_order(order_id)
        return found

    async def find_customer(self, *, customer_id: CustomerID) -> Customer:
        """Fetch the customer identified by *customer_id*."""
        found = await self.load_customer(customer_id)
        return found

    async def check_balance(self, *, order: Order, customer: Customer) -> None:
        """Raise ``RuntimeError`` unless *customer* can afford *order*."""
        if order.affordable_for(customer):
            return
        raise RuntimeError(
            f'Customer {customer.id} cannot afford order {order.id}. '
            f'Order amount: {order.amount}, Customer balance: {customer.balance}'
        )

    async def persist_payment(self, *, order: Order, customer: Customer) -> Payment:
        """Create the payment linking *order* and *customer*."""
        payment = await self.create_payment(order_id=order.id, customer_id=customer.id)
        return payment

    # Built once, at class creation, from the plain functions above; reading it
    # through an instance binds that instance as the steps' ``self``. Because
    # the step list is fixed up front, a type-hint consistency check could run
    # here at load time (none is implemented yet).
    run_cls = StoryRunner[Payment, StoryState](
        steps=(find_order, find_customer, check_balance, persist_payment),
    )

    @property
    def run(self) -> StoryRunner[Payment, State]:
        """Build a fresh runner over this instance's bound step methods.

        The runner is created on every access, so its steps are only known
        once the property is read.
        """
        bound_steps = (
            self.find_order,
            self.find_customer,
            self.check_balance,
            self.persist_payment,
        )
        return StoryRunner[Payment, State]('purchase_story', steps=bound_steps)


purchase_story = PurchaseStory(
    create_payment=create_payment,
    load_customer=load_customer,
    load_order=load_order,
)


@dataclass
class AnotherState:
    """Alternative dataclass state that uses integer identifiers.

    NOTE(review): not referenced anywhere in this snippet — possibly a leftover.
    """

    order_id: int  # integer order identifier
    customer_id: int  # integer customer identifier


async def main() -> None:
    """Demonstrate the ways of running ``PurchaseStory`` and check they agree.

    The story is run four ways:
    - with a type-keyed ``StoryState``;
    - with a name-keyed ``StoryState``;
    - with a ``State`` dataclass;
    - by awaiting the step methods directly.

    Every path must produce an equal ``Payment``.
    """
    # Fill initial state with annotations (keys are the NewType objects)
    state = StoryState({
        OrderID: OrderID('some_order_id'),
        CustomerID: CustomerID('some_customer_id'),
    })

    result1 = await purchase_story.run_cls(state)
    logger.info('Payment created: %s', result1)

    # Fill initial state with parameter names
    another_state = StoryState({
        'order_id': OrderID('some_order_id'),
        'customer_id': CustomerID('some_customer_id'),
    })
    result2 = await purchase_story.run_cls(another_state)

    # Since state is mutable we can access it after the run and read the
    # entities the steps loaded (stored under their return types).
    logger.info('Order loaded: %s', another_state[Order])
    logger.info('Customer loaded: %s', another_state[Customer])
    logger.info('Payment created: %s', result2)

    assert result1 == result2  # noqa: S101

    # Support dataclass state for useful autocompletion
    state_dc = State(
        order_id=OrderID('some_order_id'),
        customer_id=CustomerID('some_customer_id'),
    )
    result3 = await purchase_story.run(state_dc)

    logger.info('Payment created: %s', result3)
    assert result1 == result2 == result3  # noqa: S101

    # Able to use code as is without stories
    logger.info('Running without story runner...')
    order = await purchase_story.find_order(order_id=OrderID('some_order_id'))
    customer = await purchase_story.find_customer(customer_id=CustomerID('some_customer_id'))
    await purchase_story.check_balance(order=order, customer=customer)
    result4 = await purchase_story.persist_payment(order=order, customer=customer)

    assert result1 == result2 == result3 == result4  # noqa: S101


run(main())

Bahus avatar Jun 17 '25 09:06 Bahus

Hello,

Thank you for the thoughtful suggestion. I appreciate the effort you put into the design.

Over the years I have settled on a simple approach to writing use cases:

@dataclass
class Purchase:
    """Purchase use case as a single plain coroutine over injected callables."""

    find_order: Callable[[OrderID], Awaitable[Order]]
    find_customer: Callable[[CustomerID], Awaitable[Customer]]
    persist_payment: Callable[[Order, Customer], Awaitable[Payment]]

    async def make(self, order_id: OrderID, customer_id: CustomerID) -> Payment:
        """Load the order and customer, check affordability, then pay.

        ``Order.ensure_affordable_for`` is expected to raise when the customer
        cannot pay, which stops the purchase before the payment is persisted.
        """
        found_order = await self.find_order(order_id)
        found_customer = await self.find_customer(customer_id)
        found_order.ensure_affordable_for(found_customer)
        payment = await self.persist_payment(found_order, found_customer)
        return payment

I'm thinking of a decorator or wrapper that could extract a stories-like workflow log from this execution:

Purchase.make:
  find_order
  find_customer
  ensure_affordable_for
  persist_payment

order_id = 123  # argument
customer_id = 456 # argument
order = Order(123)  # set by find_order
customer = Customer(456)  # set by find_customer
Payment(789)  # returned

proofit404 avatar Jun 19 '25 00:06 proofit404