Accessing node and graph metadata/context from within a node
I'd like to be able to access node metadata (task_id, node_name, and other arbitrary parameters) from within my node functions. My particular use case is writing an async compatible MlFlowTracker plugin. (I have had intermittent success with the mlflow Fluent API trace decorator and would rather use the more explicit client API.)
From within a node function I'd like a way to access context passed in from plugin hooks.
Approach 1 - contextvars Working with @elijahbenizzy, I tried contextvars but found that the context wasn't maintained between the pre_node_execute hook, node function and the post_node_execute hook. See reproduction here: https://github.com/dmexs/hamilton-contextvar-repro
Approach 2 - modifying kwargs and parameter drilling Instead I took the approach of injecting a context object into node kwargs which works but requires parameter/prop drilling which I don't love. I'd prefer to be able to retrieve passed data from a contextvar anywhere within the node function execution stack. In my case I'm retrieving my mlflow span/request ids to create subspans on API calls.
Here's what I have for approach #2 which works:
mlflow_tracker.py
from hamilton.lifecycle import base
from hamilton import node, graph
import logging
from typing import Type, Any, Optional, Dict, List
import mlflow
import pickle
import contextvars
from pydantic import BaseModel
from contextlib import asynccontextmanager
# Module-level logger; force INFO so debug/info tracing from this plugin is
# visible even when the root logger is left at the default WARNING level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class NodeContext(BaseModel):
    """Tracing context injected into a node's ``__context`` kwarg by the tracker.

    Both ids stay ``None`` until ``pre_node_execute`` opens a span for the node;
    downstream code (e.g. the API client) checks them before creating sub-spans.
    """

    # MLflow trace request id of the currently active trace, if tracing is on.
    mlflow_request_id: Optional[str] = None
    # Span id nested calls should use as ``parent_id`` when opening sub-spans.
    mlflow_root_span_id: Optional[str] = None
class MLFlowTrackerAsync(
    base.BasePreNodeExecuteAsync,
    base.BasePostNodeExecuteAsync,
    base.BasePreGraphExecuteAsync,
    base.BasePostGraphExecuteAsync,
):
    """Async Hamilton lifecycle adapter that mirrors graph execution as an MLflow trace.

    Opens one MLflow trace per graph run (``pre_graph_execute`` /
    ``post_graph_execute``) and one child span per node. If a node declares a
    ``__context`` parameter, a :class:`NodeContext` carrying the trace/span ids
    is injected into its kwargs so nested calls can open sub-spans.

    Tracing is a no-op unless an MLflow run is active when the tracker is
    constructed.
    """

    def __init__(self, *args, **kwargs):
        # All state is per-instance. (These were previously class-level
        # attributes; the dict in particular would have been shared by every
        # tracker instance in the process.)
        self.mlflow_run_info = None
        self.mlflow_client: Optional[mlflow.MlflowClient] = None
        self.mlflow_request_id: Optional[str] = None
        self.mlflow_root_span_id: Optional[str] = None
        # Maps (task_id, node_name) -> span_id so post_node_execute can close
        # the right span. Keying on task_id alone collides: task_id is None for
        # every node outside task-based execution, so concurrently executing
        # nodes would overwrite each other's entry.
        self.mlflow_node_span_id_cache: Dict[Any, str] = {}
        if mlflow.active_run():
            self.mlflow_client = mlflow.MlflowClient()
            self.mlflow_run_info = mlflow.active_run().info
            logger.debug(f"MLFlowTrackerAsync: {self.mlflow_run_info.run_id}")

    async def pre_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        final_vars: List[str],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
    ):
        """Start the root MLflow trace for this graph run (if a run is active)."""
        if self.mlflow_run_info:
            trace = self.mlflow_client.start_trace(
                name='hamilton',
                inputs=inputs,
                experiment_id=self.mlflow_run_info.experiment_id,
                span_type=mlflow.entities.SpanType.CHAIN,
            )
            self.mlflow_request_id = trace.request_id
            self.mlflow_root_span_id = trace.span_id

    async def post_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        success: bool,
        error: Optional[Exception],
        results: Optional[Dict[str, Any]],
    ):
        """Close the root trace, attaching the graph results as outputs."""
        if self.mlflow_request_id:
            self.mlflow_client.end_trace(
                request_id=self.mlflow_request_id,
                outputs=results,
            )

    async def pre_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        kwargs: Dict[str, Any],
        task_id: Optional[str] = None,
    ) -> None:
        """Open a per-node span and inject a NodeContext if the node asks for one."""
        # ``or NodeContext()`` also covers an explicit ``__context=None`` in
        # kwargs, which ``kwargs.get('__context', NodeContext())`` would not.
        context = kwargs.get('__context') or NodeContext()
        if self.mlflow_request_id and self.mlflow_root_span_id:
            span = self.mlflow_client.start_span(
                name=node_.name,
                span_type=mlflow.entities.SpanType.CHAIN,
                request_id=self.mlflow_request_id,
                parent_id=self.mlflow_root_span_id,
                inputs=kwargs,
            )
            self.mlflow_node_span_id_cache[(task_id, node_.name)] = span.span_id
            context.mlflow_request_id = span.request_id
            context.mlflow_root_span_id = span.span_id
        if '__context' in node_.input_types:
            kwargs['__context'] = context

    async def post_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        success: bool,
        error: Optional[Exception],
        result: Any,
        task_id: Optional[str] = None,
        **future_kwargs: dict,
    ) -> None:
        """Close the node's span (if one was opened) and drop it from the cache."""
        # pop() both fetches and removes the entry, so the cache cannot grow
        # unboundedly across repeated executions.
        span_id = self.mlflow_node_span_id_cache.pop((task_id, node_.name), None)
        if span_id:
            self.mlflow_client.end_span(
                request_id=self.mlflow_request_id,
                span_id=span_id,
                outputs=result,
            )
dag.py
async def patient_fhir_id(hero_api: hero_api_async.HeroAPI, uid: str, __context: Optional[NodeContext] = None) -> str:
    """Look up a patient's FHIR resource id by UID.

    :param hero_api: async API handle; entered as an async context manager per call.
    :param uid: patient UID to resolve.
    :param __context: optional tracing context injected by the tracker hook,
        forwarded to the client so the HTTP call can be recorded as a sub-span.
    :return: the patient's FHIR resource id.
    """
    # Bug fix: the original referenced an undefined name ``api`` here
    # (``async with api as client``); the parameter is ``hero_api``.
    async with hero_api as client:
        patient = await client.request(endpoint='fhir/patient', params={"id": uid, "idType": "UID"}, context=__context)
        fhir_id = patient['id']
        return fhir_id
client.py
async def request(
    self,
    endpoint,
    method='GET',
    params: Optional[dict] = None,
    context: Optional[NodeContext] = None,
):
    """Issue an HTTP request, optionally recording it as an MLflow sub-span.

    :param endpoint: path appended to ``self.BASE_URL``.
    :param method: HTTP method, defaults to ``'GET'``.
    :param params: query parameters to send.
    :param context: tracing context; a RETRIEVER sub-span is opened only when
        it carries both a request id and a parent span id.
    :return: the JSON-decoded response body.
    :raises Exception: re-raises whatever the HTTP layer raised, after logging
        and closing the span.
    """
    # Build the URL before the try-block so a failure here cannot leave
    # ``url`` unbound inside the error handler below.
    url = self.BASE_URL + endpoint
    span = None
    if context and context.mlflow_request_id and context.mlflow_root_span_id:
        span = self.mlflow_client.start_span(
            name='API.request',
            span_type=mlflow.entities.SpanType.RETRIEVER,
            request_id=context.mlflow_request_id,
            parent_id=context.mlflow_root_span_id,
            inputs={
                "method": method,
                "endpoint": endpoint,
                "params": params,
            },
        )
    return_response = None
    # Bug fix: the original handler logged ``response.status``, but
    # ``response`` is unbound when ``self.session.request`` itself raises
    # (e.g. a connection error), so the handler raised NameError and masked
    # the real failure. Track the status in a pre-initialized variable.
    status = None
    try:
        async with self.session.request(
            method=method,
            url=url,
            params=params,
        ) as response:
            status = response.status
            response.raise_for_status()  # Raise an HTTPError for bad responses
            return_response = await response.json(content_type=None)  # Return the response data as a dictionary
            return return_response
    except Exception as e:
        logger.error(f"{method} {status} {url} - {e}")
        return_response = e
        # Bare ``raise`` preserves the original traceback.
        raise
    finally:
        if span:
            self.mlflow_client.end_span(
                request_id=span.request_id,
                span_id=span.span_id,
                outputs=return_response,
            )
OK, I want to mirror the way burr does this:
https://burr.dagworks.io/reference/application/#burr.core.application.ApplicationContext
Any node can declare __context and have it passed in as the current context variable.
Currently I'm thinking of the following schema:
@dataclasses.dataclass
class Context:
node_id: str
task_id: str
run_id: str
This would be declared as __context in a fn:
def foo(bar: str, __context: Context) -> ...:
assert __context.node_id == "foo"
...
Looks good!
I'd love to see a working contextvar approach that doesn't require passing the parameter deeply into nested functions that need it, but this works just fine too!
Yep! I think we can explore that as well (a little bit of odd stuff around parallelism/async, but I need to scope it out).