
Accessing node and graph metadata/context from within a node

Open dmexs opened this issue 10 months ago • 3 comments

I'd like to be able to access node metadata (task_id, node_name, and other arbitrary parameters) from within my node functions. My particular use case is writing an async-compatible MlFlowTracker plugin. (I have had intermittent success with the mlflow Fluent API trace decorator and would rather use the more explicit client API.)

From within a node function I'd like a way to access context passed in from plugin hooks.

**Approach 1: contextvars.** Working with @elijahbenizzy, I tried contextvars, but found that the context wasn't maintained between the pre_node_execute hook, the node function, and the post_node_execute hook. See the reproduction here: https://github.com/dmexs/hamilton-contextvar-repro
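For reference, the failure mode is consistent with how asyncio copies context at task boundaries: a ContextVar set inside a task created with `create_task` is invisible to the parent task afterwards. A minimal stdlib sketch of that behavior (names are illustrative, not from the repro):

```python
import asyncio
import contextvars

node_ctx = contextvars.ContextVar("node_ctx", default=None)

async def hook():
    # Each asyncio Task runs in a *copy* of the context it was created in,
    # so set() inside this task mutates only that copy.
    node_ctx.set("from-pre-hook")
    assert node_ctx.get() == "from-pre-hook"  # visible within the same task

async def main():
    # If the framework runs the hook in its own task, the node function
    # (back in the parent task) never sees the value the hook set.
    await asyncio.create_task(hook())
    return node_ctx.get()

result = asyncio.run(main())
print(result)  # None
```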

**Approach 2: modifying kwargs and parameter drilling.** Instead, I took the approach of injecting a context object into node kwargs. This works, but requires parameter/prop drilling, which I don't love. I'd prefer to retrieve the passed data from a contextvar anywhere within the node function's execution stack. In my case I'm retrieving my mlflow span/request ids to create subspans on API calls.

Here's what I have for approach #2, which works:

mlflow_tracker.py

from hamilton.lifecycle import base
from hamilton import node, graph
import logging
from typing import Any, Optional, Dict, List
import mlflow
from pydantic import BaseModel

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class NodeContext(BaseModel):
    mlflow_request_id: Optional[str] = None
    mlflow_root_span_id: Optional[str] = None

class MLFlowTrackerAsync(
    base.BasePreNodeExecuteAsync,
    base.BasePostNodeExecuteAsync,
    base.BasePreGraphExecuteAsync,
    base.BasePostGraphExecuteAsync,
):
    def __init__(self, *args, **kwargs):
        # Instance attributes rather than class attributes, so the span cache
        # (a mutable dict) isn't shared across tracker instances.
        self.mlflow_run_info = None
        self.mlflow_client: Optional[mlflow.MlflowClient] = None
        self.mlflow_request_id: Optional[str] = None
        self.mlflow_root_span_id: Optional[str] = None
        self.mlflow_node_span_id_cache: Dict[Any, str] = {}
        if mlflow.active_run():
            self.mlflow_client = mlflow.MlflowClient()
            self.mlflow_run_info = mlflow.active_run().info
            logger.debug(f"MLFlowTrackerAsync: {self.mlflow_run_info.run_id}")

    async def pre_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        final_vars: List[str],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
    ):
        if self.mlflow_run_info:
            trace = self.mlflow_client.start_trace(
                name='hamilton',
                inputs=inputs,
                experiment_id=self.mlflow_run_info.experiment_id,
                span_type=mlflow.entities.SpanType.CHAIN,
            )
            self.mlflow_request_id = trace.request_id
            self.mlflow_root_span_id = trace.span_id

    async def post_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        success: bool,
        error: Optional[Exception],
        results: Optional[Dict[str, Any]],
    ):
        if self.mlflow_request_id:
            self.mlflow_client.end_trace(
                request_id=self.mlflow_request_id,
                outputs=results
            )

    async def pre_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        kwargs: Dict[str, Any],
        task_id: Optional[str] = None
    ) -> None:
        context = kwargs.get('__context', NodeContext())
        if self.mlflow_request_id and self.mlflow_root_span_id:
            span = self.mlflow_client.start_span(
                name=node_.name,
                span_type=mlflow.entities.SpanType.CHAIN,
                request_id=self.mlflow_request_id,
                parent_id=self.mlflow_root_span_id,
                inputs=kwargs,
            )
            # task_id is None outside task-based execution, so include the node
            # name in the key to avoid collisions between concurrent nodes.
            self.mlflow_node_span_id_cache[(task_id, node_.name)] = span.span_id
            context.mlflow_request_id = span.request_id
            context.mlflow_root_span_id = span.span_id

        if '__context' in node_.input_types:
            kwargs['__context'] = context

    async def post_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        success: bool,
        error: Optional[Exception],
        result: Any,
        task_id: Optional[str] = None,
        **future_kwargs: dict,
    ) -> None:
        span_id = self.mlflow_node_span_id_cache.pop((task_id, node_.name), None)
        if span_id:
            self.mlflow_client.end_span(
                request_id=self.mlflow_request_id,
                span_id=span_id,
                outputs=result
            )

dag.py

async def patient_fhir_id(hero_api: hero_api_async.HeroAPI, uid: str, __context: Optional[NodeContext] = None) -> str:
    """
    Lookup patient
    """
    async with hero_api as client:
        patient = await client.request(endpoint='fhir/patient', params={"id": uid, "idType": "UID"}, context=__context)
    fhir_id = patient['id']
    return fhir_id

client.py

    async def request(
        self,
        endpoint,
        method='GET',
        params: Optional[dict] = None,
        context: Optional[NodeContext] = None
    ):
        span = None
        return_response = None
        if context and context.mlflow_request_id and context.mlflow_root_span_id:
            span = self.mlflow_client.start_span(
                name='API.request',
                span_type=mlflow.entities.SpanType.RETRIEVER,
                request_id=context.mlflow_request_id,
                parent_id=context.mlflow_root_span_id,
                inputs={
                    "method": method,
                    "endpoint": endpoint,
                    "params": params
                }
            )

        url = self.BASE_URL + endpoint  # built before try so it's bound in except
        try:
            async with self.session.request(
                method=method,
                url=url,
                params=params,
            ) as response:
                response.raise_for_status()  # raise an HTTPError for bad responses
                # content_type=None skips aiohttp's content-type check
                return_response = await response.json(content_type=None)
                return return_response

        except Exception as e:
            # Don't reference `response` here: it is unbound if the request
            # itself failed before the context manager was entered.
            logger.error(f"{method} {url} - {e}")
            return_response = e
            raise

        finally:
            if span:
                self.mlflow_client.end_span(
                    request_id=span.request_id,
                    span_id=span.span_id,
                    outputs=return_response
                )
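Aside: the `__context` injection in `pre_node_execute` above boils down to a signature check plus a kwargs merge. Here's a stdlib-only sketch of that mechanism (the helper names are mine, not Hamilton's API):

```python
import inspect

class NodeContext:
    """Stand-in for the pydantic NodeContext above."""
    def __init__(self, request_id=None, span_id=None):
        self.request_id = request_id
        self.span_id = span_id

def call_with_context(fn, kwargs, context):
    # Inject the context only when the function declares a __context
    # parameter, mirroring the check against node_.input_types above.
    if "__context" in inspect.signature(fn).parameters:
        kwargs = {**kwargs, "__context": context}
    return fn(**kwargs)

def node_with_ctx(x, __context=None):
    return (x, __context.request_id if __context else None)

def plain_node(x):
    return x

ctx = NodeContext(request_id="req-123")
print(call_with_context(node_with_ctx, {"x": 1}, ctx))  # (1, 'req-123')
print(call_with_context(plain_node, {"x": 2}, ctx))     # 2
```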

dmexs avatar Feb 26 '25 20:02 dmexs

OK, I want to mirror the way burr does this:

https://burr.dagworks.io/reference/application/#burr.core.application.ApplicationContext

Any node can declare __context and have it passed in as the current context variable.

Currently I'm thinking of the following schema:

import dataclasses

@dataclasses.dataclass
class DataflowContext:
    node_id: str
    task_id: str
    run_id: str

This would be declared as __context in a fn:

def foo(bar: str, __context: DataflowContext) -> ...:
    assert __context.node_id == "foo"
    ...
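A runnable sketch of that proposal (`frozen=True` and the simulated injection are my additions, not a committed design):

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass(frozen=True)  # immutable: nodes shouldn't mutate context
class DataflowContext:
    node_id: str
    task_id: str
    run_id: str

def foo(bar: str, __context: Optional[DataflowContext] = None) -> str:
    # The framework would populate __context; we simulate the injection here.
    assert __context.node_id == "foo"
    return f"{__context.run_id}/{__context.node_id}:{bar}"

ctx = DataflowContext(node_id="foo", task_id="t-0", run_id="r-1")
print(foo("hello", __context=ctx))  # r-1/foo:hello
```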

elijahbenizzy avatar Feb 26 '25 21:02 elijahbenizzy

Looks good!

I'd love to see a working contextvar approach that doesn't require passing the parameter deeply into nested functions that need it, but this works just fine too!
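For what it's worth, a contextvar approach should work whenever the framework awaits the hook and node coroutines directly in the same task, since a plain `await` shares the caller's context rather than copying it. A stdlib sketch of that assumption (names are illustrative):

```python
import asyncio
import contextvars

node_ctx = contextvars.ContextVar("node_ctx", default=None)

async def pre_hook():
    node_ctx.set("span-1")

async def node_fn():
    return node_ctx.get()

async def execute():
    # Awaiting coroutines directly (no create_task) keeps everything in the
    # caller's context, so the hook's set() is visible to the node function.
    await pre_hook()
    return await node_fn()

print(asyncio.run(execute()))  # span-1
```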

dmexs avatar Feb 26 '25 21:02 dmexs

Yep! I think we can explore that as well (a little bit of odd stuff around parallelism/async, but I need to scope it out).

elijahbenizzy avatar Feb 27 '25 15:02 elijahbenizzy