
Enable workload identity authentication for the Databricks provider

basvandriel opened this pull request

This pull request adds support for authenticating with Workload Identity in the Databricks provider. In the provided unit tests, this has been tested by mocking the token generation; the DefaultAzureCredential().get_token(...) part, however, has been tested on an actual AKS cluster. This closes #41586.
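
For context, a minimal sketch of what a workload-identity token fetch with DefaultAzureCredential typically looks like (illustrative only; the resource ID is a placeholder and this is not the PR's actual hook code):

from azure.identity import DefaultAzureCredential

# Illustrative only: on AKS with workload identity enabled, the credential picks up
# AZURE_CLIENT_ID, AZURE_TENANT_ID and AZURE_FEDERATED_TOKEN_FILE from the pod
# environment and exchanges the federated token for an AAD access token.
resource_id = "<azure-databricks-resource-id>"  # placeholder, not the value used in the PR

credential = DefaultAzureCredential()
access_token = credential.get_token(f"{resource_id}/.default")
print(access_token.token[:10], access_token.expires_on)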


basvandriel avatar Aug 21 '24 11:08 basvandriel

Congratulations on your first Pull Request and welcome to the Apache Airflow community! If you have any issues or are unsure about anything, please check our Contributors' Guide (https://github.com/apache/airflow/blob/main/contributing-docs/README.rst). Here are some useful points:

  • Pay attention to the quality of your code (ruff, mypy and type annotations). Our pre-commits will help you with that.
  • In case of a new feature, add useful documentation (in docstrings or in the docs/ directory). Adding a new operator? Check this short guide and consider adding an example DAG that shows how users should use it.
  • Consider using the Breeze environment for testing locally; it's a heavy Docker image, but it ships with a working Airflow and a lot of integrations.
  • Be patient and persistent. It might take some time to get a review or get the final approval from Committers.
  • Please follow ASF Code of Conduct for all communication including (but not limited to) comments on Pull Requests, Mailing list and Slack.
  • Be sure to read the Airflow Coding style.
  • Always keep your Pull Requests rebased, otherwise your build might fail due to changes not related to your commits.

Apache Airflow is a community-driven project and together we are making it better 🚀. In case of doubts contact the developers at: Mailing List: [email protected] Slack: https://s.apache.org/airflow-slack

boring-cyborg[bot] avatar Aug 21 '24 11:08 boring-cyborg[bot]

Looks okay to me, although there's quite a lot of code duplication between the _get_aad_token_for_default_az_credential and _a_get_aad_token_for_default_az_credential methods. Any way to extract common code?

BasPH avatar Sep 03 '24 09:09 BasPH

This duplication is essentially following the existing pattern that's already in place in the hook (if you look at the other methods, there is very often both a sync and an async implementation).

To reduce duplication we could break out the setup (lines 430-434) and the checking of the token (lines 450-456) into separate functions, but I'm not sure that would really improve readability.

A more heavy-weight alternative would be to only define the async version and then define the sync version using asyncio.run:

def _get_aad_token_for_default_az_credential(self, resource: str) -> str:
    return asyncio.run(self._a_get_aad_token_for_default_az_credential(resource))

However, this would add some overhead to the sync function call as asyncio.run starts an async event loop in the background:

This function runs the passed coroutine, taking care of managing the asyncio event loop, finalizing asynchronous generators, and closing the executor.

In short, to match the existing style I would prefer leaving it as is. However, I'm happy to explore these other options if needed. Also open to any suggestions from others :)

jrderuiter avatar Sep 05 '24 07:09 jrderuiter

@basvandriel Can you give me write access to your PR fork + branch? Then I can take over this PR as discussed privately.

jrderuiter avatar Sep 05 '24 07:09 jrderuiter

Hi @jrderuiter, I was already working on this. I will finalize this today. Keep you posted!

basvandriel avatar Sep 05 '24 07:09 basvandriel

@jrderuiter In that case keeping the duplicated code is okay with me. Could you fix the static code checks?

BasPH avatar Sep 05 '24 07:09 BasPH

@jrderuiter @BasPH

While running my pre-commit checks I made some updates to the duplicated code. Unfortunately, using asyncio.run is not an option: the code that is currently in use relies on a different SDK for the asynchronous part of the token generation.

However, I think readability increases by having a Callable argument that returns an AccessToken. This way we only need two functions calling the token-retrieving functions.
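
For illustration, a hypothetical sketch of that shape (names are made up here; this is not the code that ended up in the PR):

from typing import Awaitable, Callable, NamedTuple

class AccessToken(NamedTuple):  # stand-in for azure.core.credentials.AccessToken
    token: str
    expires_on: int

def get_aad_token(get_token: Callable[[], AccessToken]) -> str:
    # Shared sync wrapper: the injected callable produces the token,
    # the wrapper holds the common handling.
    return get_token().token

async def a_get_aad_token(get_token: Callable[[], Awaitable[AccessToken]]) -> str:
    # Shared async wrapper around an async token producer.
    return (await get_token()).token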

basvandriel avatar Sep 05 '24 09:09 basvandriel

I'm not really sure about the refactoring TBH; now it feels like we add an extra layer of complexity (the executor functions + get_aad_token functions) with minimal reduction in duplication.

You are right. If we want a similar approach, we would need 2 functions for the wrapping code and the rest for the variations of retrieving the token. I'm going to revert my commit and apply the other suggestions. Thanks!

basvandriel avatar Sep 05 '24 09:09 basvandriel

I see some of the static checks are failing, can you have a look at this @basvandriel? Afterwards we should be fine to merge if @BasPH agrees.

jrderuiter avatar Sep 09 '24 11:09 jrderuiter

@basvandriel Do you need any help fixing the static checks?

jrderuiter avatar Sep 13 '24 10:09 jrderuiter

Hi @jrderuiter, I'm sorry for the late reply. I will be looking into this today.

basvandriel avatar Sep 16 '24 08:09 basvandriel

Hi @basvandriel, this almost looks okay to me. There's one tiny static-check nit that thinks "aks" is a typo: https://github.com/apache/airflow/actions/runs/10882398151/job/30464268761?pr=41639#step:8:80. Could you add "aks" to docs/spelling_wordlist.txt?

The other CI/CD failure (https://github.com/apache/airflow/actions/runs/10882398151/job/30464269767?pr=41639) is one I cannot explain at first sight. I'll need to do a deep dive to understand what's failing there.

BasPH avatar Sep 23 '24 12:09 BasPH

@basvandriel @jrderuiter There's something fishy going on with the async test; see the error in https://github.com/apache/airflow/actions/runs/11039010730/job/30937844742?pr=41639.

This fails:

@pytest.mark.asyncio
@mock.patch.dict(
    os.environ,
    {
        "AZURE_CLIENT_ID": "fake-client-id",
        "AZURE_TENANT_ID": "fake-tenant-id",
        "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
        "KUBERNETES_SERVICE_HOST": "fakeip",
    },
)
@mock.patch(
    "azure.identity.aio.DefaultAzureCredential.get_token", return_value=create_aad_token_for_resource()
)
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
    requests_mock.return_value.__aenter__.return_value.json.side_effect = mock.AsyncMock(
        side_effect=[{"data": 1}]
    )

    async with self._hook:
        result = await self._hook.a_get_run_output(0)

    assert result == {"data": 1}

I found that moving the mocking of env vars inside the test works, but I'm not sure why yet:

@pytest.mark.asyncio
@mock.patch(
    "azure.identity.aio.DefaultAzureCredential.get_token", return_value=create_aad_token_for_resource()
)
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
    with mock.patch.dict(
        os.environ,
        {
            "AZURE_CLIENT_ID": "fake-client-id",
            "AZURE_TENANT_ID": "fake-tenant-id",
            "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
            "KUBERNETES_SERVICE_HOST": "fakeip",
        },
    ):
        requests_mock.return_value.__aenter__.return_value.json.side_effect = mock.AsyncMock(
            side_effect=[{"data": 1}]
        )

        async with self._hook:
            result = await self._hook.a_get_run_output(0)

        assert result == {"data": 1}
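
A possible explanation, not verified here: on some Python versions, mock.patch.dict used as a decorator on an async def applies and reverts the patch around the call that merely creates the coroutine, so the environment may already be restored by the time the coroutine body actually runs; the context-manager form keeps the patch active while the body is awaited. A minimal sketch of that difference (illustrative, not taken from the Airflow test suite):

import asyncio
import os
from unittest import mock

@mock.patch.dict(os.environ, {"AZURE_CLIENT_ID": "fake-client-id"})
async def decorated():
    # On affected versions the patch may already be reverted here, because it was
    # applied around creating the coroutine rather than around awaiting it.
    return os.environ.get("AZURE_CLIENT_ID")

async def context_managed():
    with mock.patch.dict(os.environ, {"AZURE_CLIENT_ID": "fake-client-id"}):
        # The patch stays active for the whole awaited body.
        return os.environ.get("AZURE_CLIENT_ID")

print(asyncio.run(decorated()))        # may print None on affected versions
print(asyncio.run(context_managed()))  # prints fake-client-id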

BasPH avatar Oct 21 '24 09:10 BasPH

All green now, it took a while but thanks for hanging on @basvandriel @jrderuiter!

BasPH avatar Nov 06 '24 17:11 BasPH

Awesome work, congrats on your first merged pull request! You are invited to check our Issue Tracker for additional contributions.

boring-cyborg[bot] avatar Nov 06 '24 17:11 boring-cyborg[bot]