jaxsim

Impossible to run a model with zero DoFs when JAX_DISABLE_JIT is set to True

Open · xela-95 opened this issue 8 months ago · 9 comments

Related issue on JAX: https://github.com/google/jax/issues/4668

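Building `JaxSimModelData` for such a model fails during forward kinematics. `jaxsim.rbda.forward_kinematics_model` calls `jax.lax.scan` over `jnp.arange(start=1, stop=model.number_of_links())`. For a single-link model this range is empty, and JAX refuses zero-length scans when JIT is disabled. The error is: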

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /home/acroci/repos/component_alpha/rigid_contacts_analytical.py:11
      [7](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:7) integration_time = 0.001
      [9](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:9) representation = jaxsim.VelRepr.Mixed
---> [11](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:11) data = js.data.JaxSimModelData.build(
     [12](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:12)     model=model,
     [13](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:13)     velocity_representation=representation,  # standard_gravity=7.0
     [14](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:14) )
     [15](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:15) # integrator = integrators.fixed_step.RungeKutta4SO3.build(
     [16](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:16) # integrator = integrators.fixed_step.ForwardEuler.build(
     [17](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:17) integrator = integrators.fixed_step.ForwardEulerSO3.build(
     [18](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:18)     dynamics=js.ode.wrap_system_dynamics_for_integration(
     [19](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:19)         model=model,
   (...)
     [25](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:25)     ),
     [26](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/rigid_contacts_analytical.py:26) )

File ~/repos/jaxsim/src/jaxsim/api/data.py:186, in JaxSimModelData.build(model, base_position, base_quaternion, joint_positions, base_linear_velocity, base_angular_velocity, joint_velocities, standard_gravity, contact, contacts_params, velocity_representation, time)
    [176](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:176) time_ns = (
    [177](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:177)     jnp.array(time * 1e9, dtype=jnp.uint64)
    [178](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:178)     if time is not None
    [179](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:179)     else jnp.array(0, dtype=jnp.uint64)
    [180](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:180) )
    [182](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:182) if isinstance(model.contact_model, SoftContacts):
    [183](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:183)     contacts_params = (
    [184](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:184)         contacts_params
    [185](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:185)         if contacts_params is not None
--> [186](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:186)         else js.contact.estimate_good_soft_contacts_parameters(
    [187](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:187)             model=model, standard_gravity=standard_gravity
    [188](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:188)         )
    [189](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:189)     )
    [190](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:190) else:
    [191](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/data.py:191)     contacts_params = model.contact_model.parameters

File ~/repos/jaxsim/src/jaxsim/api/contact.py:270, in estimate_good_soft_contacts_parameters(model, standard_gravity, static_friction_coefficient, number_of_active_collidable_points_steady_state, damping_ratio, max_penetration)
    [263](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:263)         return 2 * (W_pz_CoM - W_pz_C.min())
    [265](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:265)     return 2 * W_pz_CoM
    [267](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:267) max_δ = (
    [268](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:268)     max_penetration
    [269](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:269)     if max_penetration is not None
--> [270](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:270)     else 0.005 * estimate_model_height(model=model)
    [271](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:271) )
    [273](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:273) nc = number_of_active_collidable_points_steady_state
    [275](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:275) sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
    [276](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:276)     model=model,
    [277](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:277)     standard_gravity=standard_gravity,
   (...)
    [281](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:281)     damping_ratio=damping_ratio,
    [282](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:282) )

File ~/repos/jaxsim/src/jaxsim/api/contact.py:259, in estimate_good_soft_contacts_parameters.<locals>.estimate_model_height(model)
    [252](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:252) """"""
    [254](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:254) zero_data = js.data.JaxSimModelData.build(
    [255](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:255)     model=model,
    [256](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:256)     contacts_params=SoftContactsParams(),
    [257](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:257) )
--> [259](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:259) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
    [261](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:261) if model.floating_base():
    [262](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/contact.py:262)     W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]

File ~/repos/jaxsim/src/jaxsim/api/com.py:29, in com_position(model, data)
     [16](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:16) """
     [17](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:17) Compute the position of the center of mass of the model.
     [18](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:18) 
   (...)
     [24](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:24)     The position of the center of mass of the model w.r.t. the world frame.
     [25](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:25) """
     [27](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:27) m = js.model.total_mass(model=model)
---> [29](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:29) W_H_L = js.model.forward_kinematics(model=model, data=data)
     [30](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:30) W_H_B = data.base_transform()
     [31](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/com.py:31) B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()

File ~/repos/jaxsim/src/jaxsim/api/model.py:441, in forward_kinematics(model, data)
    [427](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:427) @jax.jit
    [428](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:428) def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    [429](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:429)     """
    [430](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:430)     Compute the SE(3) transforms from the world frame to the frames of all links.
    [431](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:431) 
   (...)
    [438](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:438)         The first axis is the link index.
    [439](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:439)     """
--> [441](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:441)     W_H_LL = jaxsim.rbda.forward_kinematics_model(
    [442](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:442)         model=model,
    [443](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:443)         base_position=data.base_position(),
    [444](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:444)         base_quaternion=data.base_orientation(dcm=False),
    [445](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:445)         joint_positions=data.joint_positions(model=model),
    [446](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:446)     )
    [448](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/api/model.py:448)     return jnp.atleast_3d(W_H_LL).astype(float)

File ~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:78, in forward_kinematics_model(model, base_position, base_quaternion, joint_positions)
     [74](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:74)     W_X_i = W_X_i.at[i].set(W_X_i_i)
     [76](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:76)     return (W_X_i,), None
---> [78](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:78) (W_X_i,), _ = jax.lax.scan(
     [79](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:79)     f=propagate_kinematics,
     [80](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:80)     init=propagate_kinematics_carry,
     [81](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:81)     xs=jnp.arange(start=1, stop=model.number_of_links()),
     [82](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:82) )
     [84](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:84) return jax.vmap(Adjoint.to_transform)(W_X_i)

    [... skipping hidden 1 frame]

File ~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:231, in scan(f, init, xs, length, reverse, unroll, _split_transpose)
    [229](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:229) if config.disable_jit.value:
    [230](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:230)   if length == 0:
--> [231](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:231)     raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
    [232](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:232)   carry = init
    [233](https://file+.vscode-resource.vscode-cdn.net/home/acroci/repos/component_alpha/~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:233)   ys = []

ValueError: zero-length scan is not supported in disable_jit() mode because the output type is unknown.

— xela-95, Jul 01 '24 09:07