Typed state does not work with the class-based API
Following the docs for action-level state typing does not work for class-based actions.
ref: https://github.com/DAGWorks-Inc/burr/issues/386
Current behavior
Example first action:
class SetInitialPromptAction(Action):
    @property
    def reads(self) -> list[str]:
        return []

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    @property
    def writes(self) -> list[str]:
        return ["initial_prompt"]

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        prompt = result["initial_prompt"]
        logger.info(f"Saving prompt to state: {prompt}")
        state.initial_prompt = prompt
        return state

    @property
    def inputs(self) -> list[str]:
        return ["prompt"]
Example second action:
class ExtractSetAction(Action):
    @property
    def reads(self) -> list[str]:
        return ["initial_prompt"]

    def run(self, state: ApplicationState) -> dict:
        logger.info(f"ApplicationState: {state}")
        # Read prompt from state
        prompt = state.initial_prompt
        ...
Logs: ApplicationState: {'initial_prompt': None}
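One reading consistent with the log line above and the traceback below is that the object handed to `run` and `update` is a mapping-style state rather than the annotated `ApplicationState` model. The sketch below reproduces that failure mode in isolation. `MappingState` is a hypothetical stand-in written for this illustration, not Burr's `State` class, and its `update` method only mimics the general idea of returning a new state.

```python
class MappingState:
    """Hypothetical stand-in for a mapping-style state; not Burr's State class."""

    def __init__(self, data: dict):
        self._data = dict(data)

    def __getitem__(self, key: str):
        return self._data[key]

    def update(self, **updates) -> "MappingState":
        # Return a new object rather than mutating the mapping in place.
        return MappingState({**self._data, **updates})

    def __repr__(self) -> str:
        return repr(self._data)


state = MappingState({"initial_prompt": None})

# Attribute access fails because the field lives in the mapping, not on the object.
try:
    state.initial_prompt
    attribute_access_worked = True
except AttributeError:
    attribute_access_worked = False

# Attribute assignment succeeds silently but never reaches the mapping,
# which would explain why the second action still sees None.
state.initial_prompt = "bicep curls"
value_after_assignment = state["initial_prompt"]

# A key-based update returns a new state that carries the value.
updated = state.update(initial_prompt="bicep curls")
```

If this reading is right, the `update` in the first action would write to a plain Python attribute instead of the state, and the `run` in the second action would fail on attribute access, matching both symptoms in the report.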
Stack Traces
api | ********************************************************************************
api | -------------------------------------------------------------------
api | Oh no an error! Need help with Burr?
api | Join our discord and ask for help! https://discord.gg/4FxBMyzW5n
api | -------------------------------------------------------------------
api | > Action: extract_set encountered an error!<
api | > State (at time of action):
api | {'__PRIOR_STEP': 'set_prompt',
api | '__SEQUENCE_ID': 1,
api | 'initial_prompt': None,
api | 'set_from_prompt': None}
api | > Inputs (at time of action):
api | {'prompt': 'bicep curls with 22 pound dumbells for 21 reps'}
api | ********************************************************************************
api | Traceback (most recent call last):
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | result = _run_function(
api | ^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | result = function.run(state_to_use, **inputs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | prompt = state.initial_prompt
api | ^^^^^^^^^^^^^^^^^^^^
api | AttributeError: 'State' object has no attribute 'initial_prompt'
api | INFO: 192.168.65.1:35000 - "GET /api/extract-set?prompt=bicep%20curls%20with%2022%20pound%20dumbells%20for%2021%20reps HTTP/1.1" 500 Internal Server Error
api | ERROR: Exception in ASGI application
api | + Exception Group Traceback (most recent call last):
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 77, in collapse_excgroups
api | | yield
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 186, in call
api | | async with anyio.create_task_group() as task_group:
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 736, in aexit
api | | raise BaseExceptionGroup(
api | | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
api | +-+---------------- 1 ----------------
api | | Traceback (most recent call last):
api | | File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
api | | result = await app( # type: ignore[func-returns-value]
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in call
api | | return await self.app(scope, receive, send)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in call
api | | await super().call(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in call
api | | await self.middleware_stack(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in call
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in call
api | | await self.app(scope, receive, _send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in call
api | | with collapse_excgroups():
api | | File "/usr/local/lib/python3.11/contextlib.py", line 158, in exit
api | | self.gen.throw(typ, value, traceback)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in call
api | | response = await self.dispatch_func(request, call_next)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/main.py", line 36, in log_requests
api | | response = await call_next(request)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
api | | raise app_exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
api | | await self.app(scope, receive_or_disconnect, send_no_error)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in call
api | | await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | | await app(scope, receive, sender)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in call
api | | await self.middleware_stack(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
api | | await route.handle(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
api | | await self.app(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
api | | await wrap_app_handling_exceptions(app, request)(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | | await app(scope, receive, sender)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
api | | response = await f(request)
api | | ^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
api | | raw_response = await run_endpoint_function(
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
api | | return await run_in_threadpool(dependant.call, **values)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
api | | return await anyio.to_thread.run_sync(func, *args)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
api | | return await get_async_backend().run_sync_in_worker_thread(
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
api | | return await future
api | | ^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
api | | result = context.run(func, *args)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/api/routes.py", line 53, in extract_set
api | | action, result, state = application.run(
api | | ^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
api | | return call_fn(*args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | | return fn(app_self, *args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
api | | next(gen)
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
api | | prior_action, result, state = self.step(inputs=inputs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | | return fn(app_self, *args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
api | | out = self._step(inputs=inputs, _run_hooks=True)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
api | | raise e
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | | result = _run_function(
api | | ^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | | result = function.run(state_to_use, **inputs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | | prompt = state.initial_prompt
api | | ^^^^^^^^^^^^^^^^^^^^
api | | AttributeError: 'State' object has no attribute 'initial_prompt'
api | +------------------------------------
api |
api | During handling of the above exception, another exception occurred:
api |
api | Traceback (most recent call last):
api | File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
api | result = await app( # type: ignore[func-returns-value]
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in call
api | return await self.app(scope, receive, send)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in call
api | await super().call(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in call
api | await self.middleware_stack(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in call
api | raise exc
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in call
api | await self.app(scope, receive, _send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in call
api | with collapse_excgroups():
api | File "/usr/local/lib/python3.11/contextlib.py", line 158, in exit
api | self.gen.throw(typ, value, traceback)
api | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
api | raise exc
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in call
api | response = await self.dispatch_func(request, call_next)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/app/app/main.py", line 36, in log_requests
api | response = await call_next(request)
api | ^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
api | raise app_exc
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
api | await self.app(scope, receive_or_disconnect, send_no_error)
api | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in call
api | await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | raise exc
api | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | await app(scope, receive, sender)
api | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in call
api | await self.middleware_stack(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
api | await route.handle(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
api | await self.app(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
api | await wrap_app_handling_exceptions(app, request)(scope, receive, send)
api | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | raise exc
api | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | await app(scope, receive, sender)
api | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
api | response = await f(request)
api | ^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
api | raw_response = await run_endpoint_function(
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
api | return await run_in_threadpool(dependant.call, **values)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
api | return await anyio.to_thread.run_sync(func, *args)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
api | return await get_async_backend().run_sync_in_worker_thread(
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
api | return await future
api | ^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
api | result = context.run(func, *args)
api | ^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/app/app/api/routes.py", line 53, in extract_set
api | action, result, state = application.run(
api | ^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
api | return call_fn(*args, **kwargs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | return fn(app_self, *args, **kwargs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
api | next(gen)
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
api | prior_action, result, state = self.step(inputs=inputs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | return fn(app_self, *args, **kwargs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
api | out = self._step(inputs=inputs, _run_hooks=True)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
api | raise e
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | result = _run_function(
api | ^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | result = function.run(state_to_use, **inputs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | prompt = state.initial_prompt
api | ^^^^^^^^^^^^^^^^^^^^
api | AttributeError: 'State' object has no attribute 'initial_prompt'
Steps to replicate behavior
Library & System Information
- Python 3.11
- Debian bookworm slim
- Burr library:
burr = { extras = [
"graphviz",
"hamilton",
"streamlit",
"tracking-client",
"tracking-server",
], version = "^0.31.1" }
Expected behavior
Typed state should work for class-based actions the same way it does for function-based actions.
I would like to follow this issue too. I would love to see typed state treated as a first-class citizen in Burr, and making it work with class-based actions is important, in my opinion.
Will scope this out; I think it is high value. That said, @shun-liang, you can always use centralized state, which lets you define the state model once with the application rather than separately on each class: https://burr.dagworks.io/concepts/state-typing/#application-level-typing
It will take a bit to scope, but I think we can build this out reasonably fast.
Scoped out. Some changes are needed, but nothing crazy:
- Currently this is all buried in the decorator, which I don't like. See the code: https://github.com/DAGWorks-Inc/burr/blob/2900e9e8e728aa8c78d58af053146ed2ee4ecab9/burr/integrations/pydantic.py#L170
- That said, actions can supply their own schema. These schemas are only used for inspecting and providing information internally, not for anything else: https://github.com/DAGWorks-Inc/burr/blob/2900e9e8e728aa8c78d58af053146ed2ee4ecab9/burr/integrations/pydantic.py#L148
- I'm pretty sure we can put the code from the pydantic action before/after the run/update calls, so that it is called during the action execution lifecycle. This might require "internal_run" and "internal_update" methods.
- With that, however, we can delegate to the schema's methods, which means that classes will pick it up. All they have to do is expose the schema in the right way.
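The scoping notes above can be sketched roughly as follows. This is only an illustration of the delegation idea, not Burr's implementation: `DataclassSchema`, `TypedAction`, `construct`, and `dump` are hypothetical names, a plain dict stands in for the raw state, and a standard-library dataclass stands in for the pydantic model. Only `internal_run` and `internal_update` are taken from the comment.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ApplicationState:
    """Typed view of the state (a dataclass stands in for a pydantic model)."""
    initial_prompt: Optional[str] = None


class DataclassSchema:
    """Hypothetical schema object that converts between raw and typed state."""

    def __init__(self, model: type):
        self.model = model

    def construct(self, raw: dict):
        # Build a fresh typed object from the raw state mapping.
        return self.model(**raw)

    def dump(self, typed) -> dict:
        # Convert the typed object back into a raw state mapping.
        return asdict(typed)


class TypedAction:
    """Hypothetical base class: the framework calls internal_*, users write run/update."""

    schema: DataclassSchema

    def internal_run(self, raw_state: dict, **inputs) -> dict:
        return self.run(self.schema.construct(raw_state), **inputs)

    def internal_update(self, result: dict, raw_state: dict) -> dict:
        typed = self.update(result, self.schema.construct(raw_state))
        return self.schema.dump(typed)


class SetInitialPromptAction(TypedAction):
    schema = DataclassSchema(ApplicationState)

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        state.initial_prompt = result["initial_prompt"]
        return state


action = SetInitialPromptAction()
raw = {"initial_prompt": None}
result = action.internal_run(raw, prompt="bicep curls")
new_raw = action.internal_update(result, raw)
```

In this sketch the user-written `run` and `update` only ever see the typed object, while the conversion in both directions lives in the wrapper methods, so any class that exposes a schema would pick up typing without a decorator.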