Typing State for the class-based API does not work

Open • mdrideout opened this issue 1 year ago • 3 comments

Following the docs for action-level typing does not work for class-based actions.

ref: https://github.com/DAGWorks-Inc/burr/issues/386

Current behavior

First example action:

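import logging

from burr.core import Action

logger = logging.getLogger(__name__)

# ApplicationState is the reporter's own typed state model (its definition is not shown in the issue).
class SetInitialPromptAction(Action):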
    @property
    def reads(self) -> list[str]:
        return []

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    @property
    def writes(self) -> list[str]:
        return ["initial_prompt"]

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        prompt = result["initial_prompt"]
        logger.info(f"Saving prompt to state: {prompt}")
        state.initial_prompt = prompt
        return state

    @property
    def inputs(self) -> list[str]:
        return ["prompt"]

Second example action:

class ExtractSetAction(Action):
    @property
    def reads(self) -> list[str]:
        return ["initial_prompt"]

    def run(self, state: ApplicationState) -> dict:
        logger.info(f"ApplicationState: {state}")

        # Read prompt from state
        prompt = state.initial_prompt
        ...

Logs: ApplicationState: {'initial_prompt': None}

Stack Traces

********************************************************************************
-------------------------------------------------------------------
Oh no an error! Need help with Burr?
Join our discord and ask for help! https://discord.gg/4FxBMyzW5n
-------------------------------------------------------------------
> Action: extract_set encountered an error!<
> State (at time of action):
{'__PRIOR_STEP': 'set_prompt',
 '__SEQUENCE_ID': 1,
 'initial_prompt': None,
 'set_from_prompt': None}
> Inputs (at time of action):
{'prompt': 'bicep curls with 22 pound dumbells for 21 reps'}
********************************************************************************
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
    result = _run_function(
             ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
    result = function.run(state_to_use, **inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/app/ai/actions/extract_set/action.py", line 93, in run
    prompt = state.initial_prompt
             ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'State' object has no attribute 'initial_prompt'
INFO: 192.168.65.1:35000 - "GET /api/extract-set?prompt=bicep%20curls%20with%2022%20pound%20dumbells%20for%2021%20reps HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
  + Exception Group Traceback (most recent call last):
  |   File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 77, in collapse_excgroups
  |     yield
  |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 186, in __call__
  |     async with anyio.create_task_group() as task_group:
  |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 736, in __aexit__
  |     raise BaseExceptionGroup(
  | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
    |     result = await app(  # type: ignore[func-returns-value]
    |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    |     return await self.app(scope, receive, send)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
    |     await super().__call__(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in __call__
    |     await self.middleware_stack(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in __call__
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in __call__
    |     await self.app(scope, receive, _send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in __call__
    |     with collapse_excgroups():
    |   File "/usr/local/lib/python3.11/contextlib.py", line 158, in __exit__
    |     self.gen.throw(typ, value, traceback)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in __call__
    |     response = await self.dispatch_func(request, call_next)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/app/app/main.py", line 36, in log_requests
    |     response = await call_next(request)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
    |     raise app_exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
    |     await self.app(scope, receive_or_disconnect, send_no_error)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    |     await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    |     await app(scope, receive, sender)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in __call__
    |     await self.middleware_stack(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
    |     await route.handle(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
    |     await self.app(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
    |     await wrap_app_handling_exceptions(app, request)(scope, receive, send)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    |     raise exc
    |   File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    |     await app(scope, receive, sender)
    |   File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
    |     response = await f(request)
    |                ^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
    |     raw_response = await run_endpoint_function(
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
    |     return await run_in_threadpool(dependant.call, **values)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
    |     return await anyio.to_thread.run_sync(func, *args)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    |     return await get_async_backend().run_sync_in_worker_thread(
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
    |     return await future
    |            ^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
    |     result = context.run(func, *args)
    |              ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/app/app/api/routes.py", line 53, in extract_set
    |     action, result, state = application.run(
    |                             ^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
    |     return call_fn(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
    |     return fn(app_self, *args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
    |     next(gen)
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
    |     prior_action, result, state = self.step(inputs=inputs)
    |                                   ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
    |     return fn(app_self, *args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
    |     out = self._step(inputs=inputs, _run_hooks=True)
    |           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
    |     raise e
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
    |     result = _run_function(
    |              ^^^^^^^^^^^^^^
    |   File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
    |     result = function.run(state_to_use, **inputs)
    |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/app/app/ai/actions/extract_set/action.py", line 93, in run
    |     prompt = state.initial_prompt
    |              ^^^^^^^^^^^^^^^^^^^^
    | AttributeError: 'State' object has no attribute 'initial_prompt'
    +------------------------------------

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in __call__
    raise exc
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in __call__
    with collapse_excgroups():
  File "/usr/local/lib/python3.11/contextlib.py", line 158, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
    raise exc
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in __call__
    response = await self.dispatch_func(request, call_next)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/app/main.py", line 36, in log_requests
    response = await call_next(request)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
    raise app_exc
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
    await self.app(scope, receive_or_disconnect, send_no_error)
  File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
    response = await f(request)
               ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
    return await run_in_threadpool(dependant.call, **values)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
    return await anyio.to_thread.run_sync(func, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/app/api/routes.py", line 53, in extract_set
    action, result, state = application.run(
                            ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
    return call_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
    return fn(app_self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
    next(gen)
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
    prior_action, result, state = self.step(inputs=inputs)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
    return fn(app_self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
    out = self._step(inputs=inputs, _run_hooks=True)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
    raise e
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
    result = _run_function(
             ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
    result = function.run(state_to_use, **inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/app/ai/actions/extract_set/action.py", line 93, in run
    prompt = state.initial_prompt
             ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'State' object has no attribute 'initial_prompt'

Library & System Information

  • Python 3.11
  • Debian bookworm slim
  • Burr library:
burr = { extras = [
  "graphviz",
  "hamilton",
  "streamlit",
  "tracking-client",
  "tracking-server",
], version = "^0.31.1" }

Expected behavior

Class-based actions should support typed state the same way function-based actions do.

mdrideout • Oct 16 '24 23:10

I would like to follow this issue too. I would love to see typed state treated as a first-class citizen in Burr, and making it work with class-based actions is important, imho.

shun-liang • Dec 01 '24 15:12

Will scope this out; I think it is high value. That said, @shun-liang, you can always use centralized state, which lets you define the state model centrally with the application rather than decentrally with each action class: https://burr.dagworks.io/concepts/state-typing/#application-level-typing

It will take a bit to scope, but I think we can build this out reasonably fast.

elijahbenizzy • Dec 02 '24 23:12

Scoped out. It needs some changes, but nothing crazy:

  1. Currently this is all buried in the decorator, which I don’t like. See the code: https://github.com/DAGWorks-Inc/burr/blob/2900e9e8e728aa8c78d58af053146ed2ee4ecab9/burr/integrations/pydantic.py#L170
  2. That said, actions can supply their own schema. Schemas are currently used only for inspecting/providing information internally, not for anything else: https://github.com/DAGWorks-Inc/burr/blob/2900e9e8e728aa8c78d58af053146ed2ee4ecab9/burr/integrations/pydantic.py#L148
  3. I’m pretty sure we can move the code from the pydantic action so that it runs before/after the run/update calls, i.e. during action execution as part of the lifecycle. This might require “internal_run” and “internal_update” methods.
  4. With that in place, we can delegate to the schema’s methods, which means class-based actions will pick it up. All they have to do is expose their schema the right way.
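Points 3 and 4 can be sketched in plain Python. This is not Burr code: State, DataclassSchema, TypedAction, internal_run and internal_update below are hypothetical stand-ins (the last two borrow the method names floated in point 3), and a dataclass replaces the pydantic model so the sketch has no dependencies. It assumes the framework would call internal_run/internal_update, which convert State into the typed model before the user's run/update and copy the declared writes back afterwards, so a class-based action gets attribute access.

```python
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

ModelT = TypeVar("ModelT")


class State:
    """Hypothetical stand-in for Burr's immutable State (not the real class)."""

    def __init__(self, data: Optional[Dict[str, Any]] = None) -> None:
        self._data = dict(data or {})

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def update(self, **updates: Any) -> "State":
        # Returns a new State; the original is left untouched.
        return State({**self._data, **updates})


class DataclassSchema(Generic[ModelT]):
    """Hypothetical schema that owns the State <-> typed-model conversion."""

    def __init__(self, model: Type[ModelT]) -> None:
        self.model = model

    def to_model(self, state: State, keys: List[str]) -> ModelT:
        # Build the model from the keys the action reads; other fields keep their defaults.
        return self.model(**{key: state[key] for key in keys})

    def to_state(self, model: ModelT, keys: List[str], state: State) -> State:
        # Copy only the declared writes back into a new State.
        return state.update(**{key: getattr(model, key) for key in keys})


class TypedAction:
    """Hypothetical base class: internal_run/internal_update delegate to the schema."""

    schema: DataclassSchema
    reads: List[str] = []
    writes: List[str] = []

    def internal_run(self, state: State, **inputs: Any) -> dict:
        return self.run(self.schema.to_model(state, self.reads), **inputs)

    def internal_update(self, result: dict, state: State) -> State:
        typed = self.update(result, self.schema.to_model(state, self.reads))
        return self.schema.to_state(typed, self.writes, state)

    def run(self, state: Any, **inputs: Any) -> dict:
        raise NotImplementedError

    def update(self, result: dict, state: Any) -> Any:
        raise NotImplementedError


@dataclass
class ApplicationState:
    initial_prompt: Optional[str] = None
    set_from_prompt: Optional[str] = None


SCHEMA = DataclassSchema(ApplicationState)


class SetInitialPromptAction(TypedAction):
    schema = SCHEMA
    writes = ["initial_prompt"]

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        state.initial_prompt = result["initial_prompt"]  # attribute access now works
        return state


class ExtractSetAction(TypedAction):
    schema = SCHEMA
    reads = ["initial_prompt"]
    writes = ["set_from_prompt"]

    def run(self, state: ApplicationState) -> dict:
        # Hypothetical placeholder for the elided extraction logic.
        return {"set_from_prompt": f"parsed: {state.initial_prompt}"}

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        state.set_from_prompt = result["set_from_prompt"]
        return state


# Drive both actions by hand, the way a framework step loop would.
initial_state = State({"initial_prompt": None, "set_from_prompt": None})
state = initial_state
for action, inputs in [(SetInitialPromptAction(), {"prompt": "bicep curls"}), (ExtractSetAction(), {})]:
    result = action.internal_run(state, **inputs)
    state = action.internal_update(result, state)
print(state["set_from_prompt"])
```

Keeping the conversion on the schema object rather than inside each action is what would let decorator-based and class-based actions share the same code path.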

elijahbenizzy • Jan 20 '25 05:01