Stories.
from __future__ import annotations
from asyncio import run
from collections.abc import Awaitable
from collections.abc import Callable
from dataclasses import dataclass
from pydantic import BaseModel
from stories import Story
from aioapp.entities import Customer
from aioapp.entities import Order
from aioapp.repositories import create_payment
from aioapp.repositories import CustomerId
from aioapp.repositories import load_customer
from aioapp.repositories import load_order
from aioapp.repositories import OrderId
@dataclass
class Purchase:
    """Purchase use case expressed as a four-step story.

    Dependencies are injected as dataclass fields; every step receives the same
    ``State`` object and reads from / writes to it.
    """

    @property
    def purchase(self) -> Story:
        # Steps are executed in the order they are declared here.
        return (
            Story("purchase")
            .step(self.find_order)
            .step(self.find_customer)
            .step(self.check_balance)
            .step(self.persist_payment)
        )

    async def find_order(self, state: State) -> None:
        """Load the order for ``state.order_id`` into ``state.order``."""
        state.order = await self.load_order(state.order_id)

    async def find_customer(self, state: State) -> None:
        """Load the customer for ``state.customer_id`` into ``state.customer``."""
        state.customer = await self.load_customer(state.customer_id)

    async def check_balance(self, state: State) -> None:
        """Raise ``RuntimeError`` unless the loaded customer can afford the order."""
        if state.order is None or state.customer is None:
            msg = 'check_balance requires find_order and find_customer to run first'
            raise RuntimeError(msg)
        if not state.order.affordable_for(state.customer):
            msg = f'Customer {state.customer_id} cannot afford order {state.order_id}'
            raise RuntimeError(msg)

    async def persist_payment(self, state: State) -> None:
        """Create the payment for the order/customer pair held in ``state``."""
        # ``Callable[[OrderId, CustomerId], ...]`` carries no parameter names,
        # so the dependency is called positionally. Its declared result is
        # ``None``, so nothing is stored back on the state (State has no
        # ``payment`` field, and pydantic rejects undeclared attributes).
        await self.create_payment(state.order_id, state.customer_id)

    load_order: Callable[[OrderId], Awaitable[Order]]
    load_customer: Callable[[CustomerId], Awaitable[Customer]]
    create_payment: Callable[[OrderId, CustomerId], Awaitable[None]]

    class State(BaseModel):
        """Mutable state shared by the purchase steps."""

        order_id: OrderId
        customer_id: CustomerId
        # Filled in by find_order / find_customer. They default to None so the
        # initial state can be built from the two ids alone.
        order: Order | None = None
        customer: Customer | None = None
# Wire the use case to the repository functions imported from aioapp.repositories.
purchase = Purchase(
    load_order=load_order,
    load_customer=load_customer,
    create_payment=create_payment,
)
# NOTE(review): this only validates if Purchase.State lets `order` and
# `customer` be omitted (they are filled in later by the story's steps).
state = Purchase.State(order_id=1, customer_id=1)
# Build the story via the `purchase` property and run it to completion.
run(purchase.purchase.call(state))
# Sketches of the proposed composition API. `one`, `two`, `three`, `is_ok`,
# `substory`, `two_then` and `two_else` are placeholders that are not defined
# in this snippet, so these lines do not run as-is.
# Presumably: run `two` only when `is_ok` holds — TODO confirm semantics.
Story("name").step(one).when(is_ok).step(two).call({})
# Presumably the inverse of `when`: run `two` only when `is_ok` does not hold.
Story("name").step(one).unless(is_ok).step(two).call({})
# Presumably selects between the following steps based on `is_ok` — TODO confirm.
Story("name").step(one).cond(is_ok).step(two).step(three).call({})
# Presumably runs `substory` as a nested story after `one`.
Story("name").step(one).nest(substory).call({})
# Presumably `two_then` is nested under the `when` branch and `two_else`
# follows it — TODO confirm how the branches are delimited.
Story("name").step(one).when(is_ok).nest(two_then).step(two_else).call({})
Hi! Here are my suggestions for the architecture of the Stories library, along with a working prototype. Please pay particular attention to the usage examples in the `main` function.
Highlights:
- No context-object pattern is used. This preserves natural method signatures, keeps the code more readable, and allows the methods to be used outside the Stories framework (see the example in the `main` function).
- Mypy-friendly: it supports type checking and IDE features such as autocompletion of return values and type annotations for the input state.
- The state is backed by a simple DI container; it can also be supplied as a dataclass or as an immutable container.
- Story consistency can be validated during application initialization: the type annotations of all steps can be checked to ensure the story is able to run.
I would strongly recommend keeping stories flat, since conditional execution flow adds complexity.
"""
Prerequisites:
- Python 3.11+ (the prototype imports `typing.Self`, added in 3.11)
- `kink` library installed (`pip install kink`)
"""
import logging
from asyncio import run
from collections.abc import Awaitable, Callable
from dataclasses import Field, asdict, dataclass
from decimal import Decimal
from typing import (
Any,
ClassVar,
Generic,
NewType,
Protocol,
Self,
TypeVar,
cast,
get_type_hints,
runtime_checkable,
)
from kink import Container
# Demo logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Distinct ID types: plain `str` at runtime, separate types for type checkers.
CustomerID = NewType('CustomerID', str)
OrderID = NewType('OrderID', str)
# Initial-state mapping accepted by StoryState; keys are parameter names (str)
# or types (e.g. the NewType objects above).
StateType = dict[str | type, Any]
@runtime_checkable
class DataclassLike(Protocol):
    """Structural type for objects carrying ``__dataclass_fields__``.

    Being ``runtime_checkable`` enables ``isinstance`` checks. Such a check only
    tests for the attribute's presence, so it also matches dataclass *classes*,
    not just instances.
    """

    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
class StoryState(Container):
    """
    A container to hold the state of the story.

    It can be used to pass data between steps in the story. Keys are either
    parameter names (``str``) or types.
    """

    def __init__(self, items: StateType | None = None, /) -> None:
        super().__init__()
        if items is not None:
            for key, value in items.items():
                self[key] = value

    def to_dict(self) -> dict[str | type, Any]:
        """
        Return the stored entries as a new dictionary.

        This is useful for logging or debugging. A shallow copy is returned so
        that callers cannot mutate the container's internal storage through it.
        """
        return dict(self._services)

    @classmethod
    def from_dataclass(cls, obj: 'DataclassLike') -> Self:
        """
        Build a state keyed by the field names of the dataclass instance *obj*.

        Values are taken as-is (shallow). Unlike ``dataclasses.asdict``, nested
        dataclasses, lists and dicts are neither converted to plain dicts nor
        deep-copied, so steps receive the original objects with their types.
        """
        from dataclasses import fields  # local: not imported at module level

        items: StateType = {field.name: getattr(obj, field.name) for field in fields(obj)}
        return cls(items)
@dataclass
class Customer:
    """A customer; ``balance`` is what ``Order.affordable_for`` compares against."""

    id: CustomerID
    name: str
    balance: Decimal
@dataclass
class Order:
    """An order whose ``amount`` must be covered by the paying customer's balance."""

    id: OrderID
    amount: Decimal
    description: str

    def affordable_for(self, customer: 'Customer') -> bool:
        """Tell whether *customer* has a balance of at least this order's amount."""
        return self.amount <= customer.balance
@dataclass
class Payment:
    """A payment of ``amount`` for an order by a customer, as returned by ``create_payment``."""

    order_id: OrderID
    customer_id: CustomerID
    amount: Decimal
async def load_order(order_id: OrderID) -> Order:
    """Return a fixed sample ``Order`` for *order_id*.

    Stand-in for a lookup in a database or external service.
    """
    sample_amount = Decimal('100.0')
    return Order(order_id, sample_amount, 'Sample Order')
async def load_customer(customer_id: CustomerID) -> Customer:
    """Return a fixed sample ``Customer`` for *customer_id*.

    Stand-in for a lookup in a database or external service.
    """
    sample_balance = Decimal('150.0')
    return Customer(customer_id, 'John Doe', sample_balance)
async def create_payment(order_id: OrderID, customer_id: CustomerID) -> Payment:
    """Return a sample ``Payment`` of a fixed amount for the given order and customer.

    Stand-in for persisting a payment in a database or external service.
    """
    charged = Decimal('100.0')
    return Payment(order_id, customer_id, charged)
# Type of the value a story returns as its result.
StoryReturnT = TypeVar('StoryReturnT')
# Type of the initial state a story accepts: a StoryState or a dataclass instance.
StoryStateT = TypeVar('StoryStateT', bound=StoryState | DataclassLike)
class StoryRunner(Generic[StoryReturnT, StoryStateT]):
    """Run a sequence of async steps whose arguments are wired from shared state.

    Each step's annotated parameters are looked up in a ``StoryState`` by
    annotated type and by parameter name; a name match overrides a type match.
    When a step has a non-``None`` return annotation, its result is stored in
    the state under that type (so later steps can request it) and becomes the
    current story result.

    Used as a class attribute, the runner acts as a descriptor: attribute
    access returns a copy bound to the instance (or to the class on class
    access), which is passed as ``self`` to the unbound step functions.
    """

    def __init__(
        self,
        name: str | None = None,
        steps: tuple[Callable[..., Awaitable[Any]], ...] = (),
    ) -> None:
        self.name = name
        self.steps: list[Callable[..., Awaitable[Any]]] = list(steps)
        # Passed as ``self`` to each step when set (only on copies made by __get__).
        self._story_instance: Any = None

    def step(self, func: Callable[..., Awaitable[Any]]) -> Self:
        """Append *func* as the next step and return the runner for chaining."""
        self.steps.append(func)
        return self

    def __set_name__(self, owner: type, name: str) -> None:
        # Default the story name to "Owner.attribute" when none was given.
        if not self.name:
            self.name = f'{owner.__name__}.{name}'

    def __get__(self, instance: Any, owner: type) -> Self:  # noqa: ANN401
        """Return a runner bound to *instance*, or to *owner* on class access.

        A new object is returned instead of mutating ``self``, because this
        descriptor is shared by every instance of *owner*. Binding in place
        would let two stories running concurrently (e.g. under
        ``asyncio.gather``) overwrite each other's ``self`` mid-run.
        """
        bound = type(self)(self.name)
        # Share the step list so steps added through any access stay visible,
        # exactly as when the descriptor itself was returned.
        bound.steps = self.steps
        # Test for None explicitly: a falsy instance (e.g. one whose __len__
        # returns 0) must still be bound, not replaced by the class.
        bound._story_instance = owner if instance is None else instance
        return bound

    async def __call__(self, initial_state: StoryStateT) -> StoryReturnT:
        """Run all steps in order and return the story result.

        Raises:
            TypeError: if *initial_state* is neither a StoryState nor a dataclass.
            RuntimeError: if no step produced a result.
        """
        state = self._coerce_state(initial_state)
        logger.info(
            'Running story: %s with initial state: %s',
            self.name, state.to_dict(),
        )
        story_result: StoryReturnT | None = None
        for step in self.steps:
            type_hints = get_type_hints(step)
            return_type = type_hints.pop('return', None)
            step_kwargs = self._resolve_kwargs(state, type_hints)
            logger.info('Executing step: %s with args: %s', step.__qualname__, step_kwargs)
            if self._story_instance is not None:
                step_kwargs['self'] = self._story_instance
            result = await step(**step_kwargs)
            # get_type_hints reports ``-> None`` as NoneType; such results are not stored.
            if not (return_type is None or return_type is type(None)):
                state[return_type] = result
                story_result = result
        if story_result is None:
            msg = f'Story {self.name} did not return a result.'
            raise RuntimeError(msg)
        return story_result

    @staticmethod
    def _coerce_state(initial_state: object) -> StoryState:
        """Return *initial_state* as a StoryState, converting dataclass instances."""
        if isinstance(initial_state, DataclassLike):
            return StoryState.from_dataclass(initial_state)
        if isinstance(initial_state, StoryState):
            return initial_state
        msg = f'Expected StoryState or dataclass instance, got {type(initial_state)}'
        raise TypeError(msg)

    @staticmethod
    def _resolve_kwargs(state: StoryState, type_hints: dict[str, Any]) -> dict[str, Any]:
        """Pick a value from *state* for each annotated parameter; names override types."""
        step_kwargs: dict[str, Any] = {}
        for param, param_type in type_hints.items():
            if param_type in state:
                step_kwargs[param] = state[param_type]
            if param in state:
                value = state[param]
                # NewType is erased at runtime, so check against its supertype.
                if type(param_type) is NewType and not isinstance(
                    value,
                    cast('type | tuple[type, ...]', param_type.__supertype__),
                ):
                    logger.warning(
                        'Parameter "%s" has type %s but is provided as %s',
                        param,
                        param_type,
                        type(value),
                    )
                step_kwargs[param] = value
        return step_kwargs
class CreatePaymentProtocol(Protocol):
    """Async callable that creates a ``Payment`` for an order and a customer.

    A Protocol (rather than ``Callable[...]``) keeps the parameter names, which
    matters because ``PurchaseStory.persist_payment`` calls it with keywords.
    """

    async def __call__(self, order_id: OrderID, customer_id: CustomerID) -> Payment: ...
class LoadOrderProtocol(Protocol):
    """Async callable that loads an ``Order`` by its id."""

    async def __call__(self, order_id: OrderID) -> Order: ...
class LoadCustomerProtocol(Protocol):
    """Async callable that loads a ``Customer`` by its id."""

    async def __call__(self, customer_id: CustomerID) -> Customer: ...
@dataclass
class State:
    """Dataclass form of the purchase story's initial state.

    Field names match the step parameters they feed (``order_id``, ``customer_id``).
    """

    # NOTE(review): annotated as plain `str`, while `customer_id` uses the
    # NewType; presumably this is meant to be `OrderID` — confirm (main() passes
    # a bare str here).
    order_id: str
    customer_id: CustomerID
@dataclass
class PurchaseStory:
    """Purchase use case: load order and customer, check balance, create payment.

    Step methods take keyword-only arguments that ``StoryRunner`` fills from the
    story state by parameter name or annotated type. A step's result is stored
    under its return annotation (unless that is ``None``). The parameter names
    and annotations below are therefore significant.
    """

    load_order: LoadOrderProtocol
    load_customer: LoadCustomerProtocol
    create_payment: CreatePaymentProtocol

    async def find_order(self, *, order_id: OrderID) -> Order:
        """Load the order identified by *order_id*."""
        return await self.load_order(order_id)

    async def find_customer(self, *, customer_id: CustomerID) -> Customer:
        """Load the customer identified by *customer_id*."""
        return await self.load_customer(customer_id)

    async def check_balance(self, *, order: Order, customer: Customer) -> None:
        """Raise ``RuntimeError`` if *customer* cannot afford *order*."""
        if not order.affordable_for(customer):
            msg = (
                f'Customer {customer.id} cannot afford order {order.id}. '
                f'Order amount: {order.amount}, Customer balance: {customer.balance}'
            )
            raise RuntimeError(msg)

    async def persist_payment(self, *, order: Order, customer: Customer) -> Payment:
        """Create and return the payment for *order* by *customer*."""
        return await self.create_payment(
            order_id=order.id,
            customer_id=customer.id,
        )

    # possible to validate story consistency with type hints
    # on application load (once)
    # The steps here are the plain (unbound) functions defined above; when the
    # runner is accessed through an instance, StoryRunner.__get__ records that
    # instance and it is passed to each step as ``self``.
    run_cls = StoryRunner[Payment, StoryState](
        steps=(
            find_order,
            find_customer,
            check_balance,
            persist_payment,
        ),
    )

    # also works but cannot validate consistency on application load
    # A new runner over the bound methods is built on every attribute access.
    @property
    def run(self) -> StoryRunner[Payment, State]:
        return StoryRunner[Payment, State](
            'purchase_story',
            steps=(
                self.find_order,
                self.find_customer,
                self.check_balance,
                self.persist_payment,
            ),
        )
# Story instance wired to the simulated repository functions defined above.
purchase_story = PurchaseStory(
    create_payment=create_payment,
    load_customer=load_customer,
    load_order=load_order,
)
@dataclass
class AnotherState:
    """Alternative state shape with integer ids.

    NOTE(review): not referenced anywhere in the visible code (main() builds a
    name-keyed StoryState instead); confirm whether it is still needed.
    """

    order_id: int
    customer_id: int
async def main() -> None:
    """Demonstrate the different ways to run ``purchase_story``.

    The same purchase runs four times: with type-keyed state, name-keyed state,
    a dataclass state, and by calling the step methods directly. Each run
    asserts that the resulting payments are equal.
    """
    # Fill initial state with annotations: keys are the NewType objects, so
    # each step parameter is matched by its annotated type.
    state = StoryState({
        OrderID: OrderID('some_order_id'),
        CustomerID: CustomerID('some_customer_id'),
    })
    result1 = await purchase_story.run_cls(state)
    logger.info('Payment created: %s', result1)

    # Fill initial state with parameter names: each step parameter is matched
    # by its name.
    another_state = StoryState({
        'order_id': OrderID('some_order_id'),
        'customer_id': CustomerID('some_customer_id'),
    })
    result2 = await purchase_story.run_cls(another_state)
    # Since state is mutable we can access it after the run and read the
    # entities the steps stored in it (keyed by their return types).
    logger.info('Order created: %s', another_state[Order])
    logger.info('Customer created: %s', another_state[Customer])
    logger.info('Payment created: %s', result2)
    assert result1 == result2  # noqa: S101

    # Support dataclass state for useful autocompletion; the runner converts it
    # to a StoryState keyed by field names.
    state_dc = State(
        order_id='some_order_id',
        customer_id=CustomerID('some_customer_id'),
    )
    result3 = await purchase_story.run(state_dc)
    logger.info('Payment created: %s', result3)
    assert result1 == result2 == result3  # noqa: S101

    # Able to use code as is without stories: call each step with explicit
    # keyword arguments.
    logger.info('Running without story runner...')
    order = await purchase_story.find_order(order_id=OrderID('some_order_id'))
    customer = await purchase_story.find_customer(customer_id=CustomerID('some_customer_id'))
    await purchase_story.check_balance(order=order, customer=customer)
    result4 = await purchase_story.persist_payment(order=order, customer=customer)
    assert result1 == result2 == result3 == result4  # noqa: S101
# Run the demo only when executed as a script, so importing this module
# (e.g. from tests or other code) does not execute the demo and its asserts.
if __name__ == '__main__':
    run(main())
Hello,
Thank you for the thoughtful suggestion. I appreciate the effort you put into the design.
Over the years, I have settled on a simple approach to writing use cases:
@dataclass
class Purchase:
    """Purchase use case written as one plain async method over injected callables."""

    find_order: Callable[[OrderID], Awaitable[Order]]
    find_customer: Callable[[CustomerID], Awaitable[Customer]]
    persist_payment: Callable[[Order, Customer], Awaitable[Payment]]

    async def make(self, order_id: OrderID, customer_id: CustomerID) -> Payment:
        """Load the order and customer, call ``ensure_affordable_for``, persist the payment."""
        # Local names (order, customer) are kept meaningful on purpose: the
        # workflow log sketched below reports them by name.
        order = await self.find_order(order_id)
        customer = await self.find_customer(customer_id)
        # NOTE(review): the Order in the prototype above exposes
        # affordable_for(), not ensure_affordable_for(); this snippet presumably
        # assumes an Order method that raises when unaffordable — confirm.
        order.ensure_affordable_for(customer)
        return await self.persist_payment(order, customer)
I'm thinking of a decorator or wrapper that could extract a Stories-like workflow log from this execution:
Purchase.make:
find_order
find_customer
ensure_affordable_for
persist_payment
order_id = 123 # argument
customer_id = 456 # argument
order = Order(123) # set by find_order
customer = Customer(456) # set by find_customer
Payment(789) # returned