
[Good First Issue][JAX FE]: Support jax.lax.iota operation for JAX

Open rkazants opened this issue 1 year ago • 6 comments

Context

The OpenVINO component responsible for supporting JAX/Flax models is called the JAX Frontend (JAX FE). JAX FE converts a JAX/Flax model, represented as a ClosedJaxpr graph object with operations from the jax.lax opset, into OpenVINO IR containing operations from the OpenVINO opset.

To infer JAX/Flax models that contain the jax.lax.iota operation with OpenVINO, JAX FE needs to be extended to support this operation.

What needs to be done?

To support the jax.lax.iota operation, you need to implement the corresponding loader in the JAX FE op directory and register it in the dictionary of loaders. Each loader is responsible for converting (or decomposing) one type of JAX operation.
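The registration step can be pictured as a map from JAX primitive names to translator functions. The sketch below is a self-contained toy of that pattern in plain C++, not the real JAX FE code: ToyNodeContext, ToyOutputVector, toy_loaders, and toy_convert are hypothetical stand-ins, and each toy loader only returns the names of the OpenVINO operations it would emit.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real JAX FE types: the context names the JAX
// primitive and its constant params; the "output" lists emitted OpenVINO ops.
struct ToyNodeContext {
    std::string op_type;                                   // e.g. "iota"
    std::map<std::string, std::vector<long long>> params;  // constant params
};
using ToyOutputVector = std::vector<std::string>;
using ToyLoader = std::function<ToyOutputVector(const ToyNodeContext&)>;

// One loader per JAX operation type.
ToyOutputVector toy_translate_reshape(const ToyNodeContext& ctx) {
    // Mirrors translate_reshape: a Transpose is emitted only when "dimensions" is present.
    if (ctx.params.count("dimensions")) return {"Transpose", "Reshape"};
    return {"Reshape"};
}

ToyOutputVector toy_translate_iota(const ToyNodeContext&) {
    // The decomposition suggested for iota in this thread: Range, then Convert.
    return {"Range", "Convert"};
}

// The toy "dictionary of loaders": JAX primitive name -> loader.
const std::map<std::string, ToyLoader>& toy_loaders() {
    static const std::map<std::string, ToyLoader> table = {
        {"reshape", toy_translate_reshape},
        {"iota", toy_translate_iota},  // the new registration this issue asks for
    };
    return table;
}

// Dispatch a node to its registered loader, failing loudly for unknown ops.
ToyOutputVector toy_convert(const ToyNodeContext& ctx) {
    const auto& table = toy_loaders();
    auto it = table.find(ctx.op_type);
    if (it == table.end()) throw std::runtime_error("unsupported JAX op: " + ctx.op_type);
    return it->second(ctx);
}
```

Without the "iota" entry, toy_convert would throw for an iota node, which is the toy analogue of the unsupported-operation situation this issue describes.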

Here is an example of a loader implementation for the jax.lax.reshape operation:

OutputVector translate_reshape(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    Output<Node> input = context.get_input(0);
    auto new_sizes = context.const_named_param<std::vector<int64_t>>("new_sizes");
    if (context.has_param("dimensions")) {
        auto dimensions = context.const_named_param<std::vector<int64_t>>("dimensions");
        // transpose the input first.
        auto permutation_node = std::make_shared<v0::Constant>(element::i64, Shape{dimensions.size()}, dimensions);
        input = std::make_shared<v1::Transpose>(input, permutation_node);
    }

    auto new_shape_node = std::make_shared<v0::Constant>(element::i64, Shape{new_sizes.size()}, new_sizes);
    Output<Node> res = std::make_shared<v1::Reshape>(input, new_shape_node, false);
    return {res};
}

In this example, translate_reshape expresses jax.lax.reshape using the OpenVINO opset. According to the JAX documentation, jax.lax.reshape can transpose the input and then reshape it, so the resulting decomposition contains OpenVINO Transpose and Reshape operations. To build these nodes, the translator reads two constant parameters: dimensions, the permutation applied to the input tensor (present only when a transposition is requested), and new_sizes, the target shape of the result.
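To make the shape bookkeeping concrete, here is a plain C++ sketch (no OpenVINO types; the helper names permuted_shape, num_elements, and reshape_is_valid are hypothetical) of what the two parameters do: dimensions permutes the input shape the way Transpose does, and new_sizes must describe the same number of elements, since Reshape with special_zero set to false only reinterprets the layout.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <vector>

using Shape64 = std::vector<int64_t>;

// Shape after a Transpose with permutation `perm`: output axis i takes input axis perm[i].
Shape64 permuted_shape(const Shape64& shape, const Shape64& perm) {
    if (perm.size() != shape.size()) throw std::invalid_argument("permutation rank mismatch");
    Shape64 out;
    out.reserve(perm.size());
    for (int64_t axis : perm) {
        // .at() throws std::out_of_range for an invalid axis (negatives wrap to huge indices).
        out.push_back(shape.at(static_cast<std::size_t>(axis)));
    }
    return out;
}

// Total element count of a shape (1 for a scalar).
int64_t num_elements(const Shape64& shape) {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}

// Transpose never changes the element count, so validity reduces to comparing
// counts; permuted_shape is still called so that a malformed `dimensions` throws.
bool reshape_is_valid(const Shape64& input, const std::optional<Shape64>& dimensions,
                      const Shape64& new_sizes) {
    const Shape64 transposed = dimensions ? permuted_shape(input, *dimensions) : input;
    return num_elements(transposed) == num_elements(new_sizes);
}
```

For example, an input of shape (2, 3, 4) with dimensions (2, 0, 1) is transposed to (4, 2, 3), which holds 24 elements and can therefore be reshaped to new_sizes (4, 6).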

Once you are done implementing the translator, write the corresponding layer test, test_iota.py, and put it in the layer_tests/jax_tests directory. Here is an example of how to run a layer test:

export TEST_DEVICE=CPU
export JAX_TRACE_MODE=JAXPR
export 
cd openvino/tests/layer_tests/jax_tests
pytest test_reshape.py

Example Pull Requests

  • https://github.com/openvinotoolkit/openvino/pull/26288
  • https://github.com/openvinotoolkit/openvino/pull/26254

Resources

Contact points

  • @openvinotoolkit/openvino-jax-frontend-maintainers
  • @rkazants in GitHub and Discord

Ticket

No response

rkazants, Sep 12 '24

@rkazants I would like to be assigned this issue.

muhd360, Sep 13 '24


The task is yours, welcome :)

rkazants, Sep 13 '24

thx

muhd360, Sep 13 '24

Hi @rkazants, is there any additional reading available?

muhd360, Sep 15 '24


Hi @muhd360,

It is a simple operation that generates the range [0, n-1], where n is the value of size. See the documentation at https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.iota.html for details. You can use the Range operation from our opset. Finally, the resulting value should be cast to the required type.
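The suggested decomposition can be modeled without OpenVINO: produce Range(0, n, 1) in a wide integer type, then convert each value to the target dtype. The C++ sketch below also covers the general broadcasted form, under the assumption (true for recent JAX versions, but worth checking against the JAXPR you actually get) that the iota equation carries dtype, shape, and dimension params, with jax.lax.iota(dtype, size) being the case shape = (size,), dimension = 0. The function names are hypothetical; in a real loader the first two steps would correspond to OpenVINO Range and Convert nodes, and rank > 1 would need an additional reshape/broadcast, which is one possible approach rather than a prescribed one.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Step 1 (models OpenVINO Range): start, start + step, ... strictly below stop.
std::vector<int64_t> range_i64(int64_t start, int64_t stop, int64_t step) {
    if (step <= 0) throw std::invalid_argument("this sketch only handles positive steps");
    std::vector<int64_t> out;
    for (int64_t v = start; v < stop; v += step) out.push_back(v);
    return out;
}

// Step 2 (models OpenVINO Convert): cast every value to the target dtype T.
template <typename T>
std::vector<T> convert_to(const std::vector<int64_t>& in) {
    return std::vector<T>(in.begin(), in.end());
}

// Broadcasted iota under the assumed shape/dimension params: each element equals
// its coordinate along `dimension`. Returns a row-major flattened buffer of T.
template <typename T>
std::vector<T> broadcasted_iota(const std::vector<int64_t>& shape, std::size_t dimension) {
    if (dimension >= shape.size()) throw std::out_of_range("dimension out of range");
    int64_t total = 1;
    for (int64_t d : shape) {
        if (d < 0) throw std::invalid_argument("negative dimension");
        total *= d;
    }
    // Range along the chosen axis, then convert to the target dtype.
    const std::vector<T> axis_values = convert_to<T>(range_i64(0, shape[dimension], 1));
    // Row-major stride of `dimension`: product of all dims after it.
    int64_t inner = 1;
    for (std::size_t i = dimension + 1; i < shape.size(); ++i) inner *= shape[i];
    std::vector<T> out;
    out.reserve(static_cast<std::size_t>(total));
    for (int64_t flat = 0; flat < total; ++flat) {
        // Coordinate along `dimension` for this flat row-major offset.
        const int64_t coord = (flat / inner) % shape[dimension];
        out.push_back(axis_values[static_cast<std::size_t>(coord)]);
    }
    return out;
}
```

For instance, broadcasted_iota<int32_t>({5}, 0) gives 0, 1, 2, 3, 4 (plain jax.lax.iota), while shape (2, 3) with dimension 1 gives rows {0, 1, 2} and {0, 1, 2}, and with dimension 0 gives rows {0, 0, 0} and {1, 1, 1}.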

See examples of decompositions for other operation here: https://github.com/openvinotoolkit/openvino/tree/master/src/frontends/jax/src/op

Best regards, Roman

rkazants, Sep 17 '24

@rkazants Hi, I've got it; it's just failing a few unit tests. Hopefully I'll close it by today.

muhd360, Sep 18 '24

.take

kumar-sanjeeev, Nov 12 '24

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

github-actions[bot], Nov 12 '24

@rkazants Hi,

I started working on this issue but need to confirm one thing. After I have implemented everything (i.e., the iota loader in jax/op, its registration, and the test cases), do I need to first run the build process again and then execute the commands you mentioned in the previous conversation? Or can I directly run the tests without rebuilding the code?

export TEST_DEVICE=CPU
export JAX_TRACE_MODE=JAXPR
export 
cd openvino/tests/layer_tests/jax_tests
pytest test_iota.py

kumar-sanjeeev, Nov 13 '24

.take

11happy, Dec 22 '24

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

github-actions[bot], Dec 22 '24

This issue will be closed in a week because of 9 months of no activity.

github-actions[bot], Sep 23 '25

This issue was closed because it has been stalled for 9 months with no activity.

github-actions[bot], Sep 30 '25