Enable workload identity authentication for the Databricks provider
This pull request adds support for authenticating with Workload Identity in the Databricks provider. In the provided unit tests, this has been tested by mocking the token generation; the DefaultAzureCredential().get_token(...) call has additionally been verified on an actual AKS cluster. This closes #41586.
Congratulations on your first Pull Request and welcome to the Apache Airflow community! If you have any issues or are unsure about anything, please check our Contributors' Guide (https://github.com/apache/airflow/blob/main/contributing-docs/README.rst). Here are some useful points:
- Pay attention to the quality of your code (ruff, mypy and type annotations). Our pre-commits will help you with that.
- In case of a new feature add useful documentation (in docstrings or in the docs/ directory). Adding a new operator? Check this short guide. Consider adding an example DAG that shows how users should use it.
- Consider using the Breeze environment for testing locally; it's a heavy docker image but it ships with a working Airflow and a lot of integrations.
- Be patient and persistent. It might take some time to get a review or get the final approval from Committers.
- Please follow ASF Code of Conduct for all communication including (but not limited to) comments on Pull Requests, Mailing list and Slack.
- Be sure to read the Airflow Coding style.
- Always keep your Pull Requests rebased, otherwise your build might fail due to changes not related to your commits.

Apache Airflow is a community-driven project and together we are making it better 🚀. In case of doubts contact the developers at: Mailing List: [email protected] Slack: https://s.apache.org/airflow-slack
Looks okay to me, although there's quite a lot of code duplication between the _get_aad_token_for_default_az_credential and _a_get_aad_token_for_default_az_credential methods. Any way to extract the common code?
This duplication is essentially following the existing pattern that's already in place in the hook (if you look at the other methods, there is very often both a sync and an async implementation).
To reduce duplication we could break out the setup (lines 430-434) and the checking of the token (lines 450-456) into separate functions, but I'm not sure if that would really improve readability.
A more heavy-weight alternative would be to only define the async version and then define the sync version using asyncio.run:
def _get_aad_token_for_default_az_credential(self, resource: str) -> str:
    return asyncio.run(self._a_get_aad_token_for_default_az_credential(resource))
However, this would add some overhead to the sync function call as asyncio.run starts an async event loop in the background:
This function runs the passed coroutine, taking care of managing the asyncio event loop, finalizing asynchronous generators, and closing the executor.
In short, to match the existing style I would prefer leaving it as is. However, I'm happy to explore these other options if needed. Also open to any suggestions from others :)
@basvandriel Can you give me write access to your PR fork + branch? Then I can take over this PR as discussed privately.
Hi @jrderuiter, I was already working on this. I will finalize this today. I'll keep you posted!
@jrderuiter In that case keeping the duplicated code is okay with me. Could you fix the static code checks?
@jrderuiter @BasPH
While running my pre-commit checks I made some updates to the duplicated code. Unfortunately, using asyncio.run is not an option: the code currently in use relies on a different SDK for the asynchronous part of the token generation.
However, I think readability increases by having a Callable argument where the function returns an AccessToken. This way we only need 2 functions calling the token-retrieving functions.
I'm not really sure about the refactoring TBH; now it feels like we add an extra layer of complexity (the executor functions + get_aad_token functions) with minimal reduction in duplication.
You are right. If we want a similar approach, we would need 2 functions for the wrapping code and the rest for the variations of retrieving the token. I'm going to revert my commit and apply the other suggestions. Thanks!
I see some of the static checks are failing, can you have a look at this @basvandriel? Afterwards we should be fine to merge if @BasPH agrees.
@basvandriel Do you need any help fixing the static checks?
Hi @jrderuiter, I'm sorry for the late reply. I will be looking into this today.
Hi @basvandriel almost looks okay to me. There's one tiny static check nit that thinks that "aks" is a typo: https://github.com/apache/airflow/actions/runs/10882398151/job/30464268761?pr=41639#step:8:80. Could you add "aks" to docs/spelling_wordlist.txt?
The other CI/CD failure (https://github.com/apache/airflow/actions/runs/10882398151/job/30464269767?pr=41639) is one I cannot explain at first sight. Will need to do a deep dive to understand what's failing there.
@basvandriel @jrderuiter There's something fishy going on with the async test, see error in https://github.com/apache/airflow/actions/runs/11039010730/job/30937844742?pr=41639.
This fails:
@pytest.mark.asyncio
@mock.patch.dict(
    os.environ,
    {
        "AZURE_CLIENT_ID": "fake-client-id",
        "AZURE_TENANT_ID": "fake-tenant-id",
        "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
        "KUBERNETES_SERVICE_HOST": "fakeip",
    },
)
@mock.patch(
    "azure.identity.aio.DefaultAzureCredential.get_token", return_value=create_aad_token_for_resource()
)
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
    requests_mock.return_value.__aenter__.return_value.json.side_effect = mock.AsyncMock(
        side_effect=[{"data": 1}]
    )
    async with self._hook:
        result = await self._hook.a_get_run_output(0)
    assert result == {"data": 1}
I found that moving the mocking of env vars inside the test works, but I'm not sure why yet:
@pytest.mark.asyncio
@mock.patch(
    "azure.identity.aio.DefaultAzureCredential.get_token", return_value=create_aad_token_for_resource()
)
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
    with mock.patch.dict(
        os.environ,
        {
            "AZURE_CLIENT_ID": "fake-client-id",
            "AZURE_TENANT_ID": "fake-tenant-id",
            "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
            "KUBERNETES_SERVICE_HOST": "fakeip",
        },
    ):
        requests_mock.return_value.__aenter__.return_value.json.side_effect = mock.AsyncMock(
            side_effect=[{"data": 1}]
        )
        async with self._hook:
            result = await self._hook.a_get_run_output(0)
        assert result == {"data": 1}
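One plausible explanation (an assumption, not verified against this CI run): in older Python/mock versions, `mock.patch.dict` used as a decorator on an `async def` applied and reverted the patch around coroutine *creation* rather than execution, so the env vars were already gone by the time the awaited test body ran; newer Python versions added async support to `patch.dict`. Using it as a context manager inside the test keeps the patch active while the coroutine actually runs. A minimal self-contained demonstration of the context-manager form:

```python
import asyncio
import os
from unittest import mock


async def read_env():
    # Reads the variable at coroutine *execution* time.
    return os.environ.get("DEMO_VAR")


async def main():
    # The patch stays active for the duration of the awaited call ...
    with mock.patch.dict(os.environ, {"DEMO_VAR": "patched"}):
        assert await read_env() == "patched"
    # ... and is cleanly reverted once the context manager exits.
    assert os.environ.get("DEMO_VAR") != "patched"


asyncio.run(main())
```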
All green now, it took a while but thanks for hanging on @basvandriel @jrderuiter!
Awesome work, congrats on your first merged pull request! You are invited to check our Issue Tracker for additional contributions.