brax
Joints and Actuators are Not Aligned
Here's my humanoid character: https://pastebin.com/7Ym4aqy6
Now, borrowing from the existing environment code, if I use something like
```python
joint_1d_angle, joint_1d_vel = self.sys.joints[0].angle_vel(qp)
joint_2d_angle, joint_2d_vel = self.sys.joints[1].angle_vel(qp)
joint_3d_angle, joint_3d_vel = self.sys.joints[2].angle_vel(qp)
joint_angles = jnp.concatenate([
    joint_1d_angle[0],
    joint_2d_angle[0],
    joint_2d_angle[1],
    joint_3d_angle[0],
    joint_3d_angle[1],
    joint_3d_angle[2],
])
```
then I get an array of size `self.sys.num_joint_dof`. The actuation (torque) passed into `self.sys.step()` has the same size. However, actuators and joints are not indexed consistently. For example, in my environment `act[-1]` is the actuation of `left_elbow`, but `joint_angles[-1]` is not the angle of `left_elbow`.
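A toy reproduction without brax may help show one part of the ordering. It assumes, based only on the indexing in the snippet above and not checked against brax, that `angle_vel` returns a tuple with one array per axis, each holding one entry per joint of that type. Under that assumption, concatenating `angle[0]`, `angle[1]`, ... lists every joint's first axis before any second axis. So even within one joint type, the flat vector is axis-major rather than joint-major. This is separate from, and in addition to, the grouping of joints by DoF. The joint names are hypothetical.

```python
import jax.numpy as jnp

# Two hypothetical 2-DoF joints. Angles encode joint_id * 10 + axis so the
# resulting order is easy to read.
names_2d = ["left_hip", "right_hip"]

# Assumed angle_vel layout: one array per axis, one entry per joint.
joint_2d_angle = (
    jnp.array([0.0, 10.0]),  # axis 0 of left_hip, right_hip
    jnp.array([1.0, 11.0]),  # axis 1 of left_hip, right_hip
)

# Layout produced by the concatenate snippet: axis-major
# (left_hip ax0, right_hip ax0, left_hip ax1, right_hip ax1).
axis_major = jnp.concatenate([joint_2d_angle[0], joint_2d_angle[1]])

# Joint-major layout: all axes of one joint before the next joint
# (left_hip ax0, left_hip ax1, right_hip ax0, right_hip ax1).
joint_major = jnp.stack(joint_2d_angle).T.reshape(-1)
```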
I suspect this comes from the joint and actuator indices being computed separately. The actuators appear to be ordered by their appearance in my config, while the joints are grouped by degrees of freedom (I'm not sure what the order is within each group, possibly order of appearance). This makes it hard to apply a PD controller, because each torque has to be computed from the angle and velocity of the joint it drives.
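For context, a PD controller is an elementwise law, so the torque, angle, and velocity vectors must all share one index order. Below is a minimal sketch. `pd_torque`, `kp`, `kd`, and `target` are hypothetical names, and `q`/`qd` are assumed to already be in actuator order:

```python
import jax.numpy as jnp


def pd_torque(q, qd, target, kp, kd):
    """Elementwise PD law: tau[i] = kp[i] * (target[i] - q[i]) - kd[i] * qd[i].

    All arguments must use the same (actuator) ordering. If q[i] belongs to a
    different joint than the one act[i] drives, the torque acts on the wrong
    joint.
    """
    return kp * (target - q) - kd * qd


# Made-up values for three DoFs, already in actuator order.
q = jnp.array([0.1, -0.2, 0.0])
qd = jnp.array([0.0, 0.5, -1.0])
target = jnp.zeros(3)
act = pd_torque(q, qd, target, kp=jnp.full(3, 10.0), kd=jnp.full(3, 1.0))
# act is approximately [-1.0, 1.5, 1.0]
```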
In actuators.py:
In joints.py:
My current workaround is to read `self.sys.actuators[0].act_index` and so on, concatenate them, and then reorder `joint_angles` so that the joint indices match the actuator indices. However, it would be much more convenient if joints and actuators were aligned by default.
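The "concatenate and sort" idea can be written as a permutation computed once from the actuator indices. This sketch assumes two things. First, the flattened angles are joint-major (all axes of one joint together). Second, every joint DoF has exactly one actuator slot, so the concatenated `act_index` values form a permutation of `0..num_joint_dof-1`. The arrays are made-up stand-ins for the concatenated `act_index` values and joint angles:

```python
import jax.numpy as jnp

# Made-up stand-ins: act_idx[i] is the actuator slot of flat joint DoF i.
act_idx = jnp.array([2, 0, 3, 1])
# Joint DoF i has angle q_flat[i] (chosen as act_idx[i] * 0.1 for readability).
q_flat = jnp.array([0.2, 0.0, 0.3, 0.1])

# Sort formulation: perm[k] is the joint DoF that feeds actuator slot k.
perm = jnp.argsort(act_idx)
aligned_sort = q_flat[perm]

# Equivalent scatter formulation, as used with .at[...].set in get_joint_angles.
aligned_scatter = jnp.zeros_like(q_flat).at[act_idx].set(q_flat)
```

Because `act_index` does not change during a rollout, `perm` could be computed once when the environment is built and reused every step.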
In more detail, something like this works for now:
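```python
import jax.numpy as jnp
from brax.physics.base import QP


def get_joint_angles(self, qp: QP):
    """Return joint angles in the same order as the actuators.

    Relies on sys.actuators[k] and sys.joints[k] listing the same joints
    (1-, 2- and 3-DoF) in the same order.
    """
    # Actuator slots for each joint type, one row of DoF slots per joint.
    joint_1d_act_index = self.sys.actuators[0].act_index
    joint_2d_act_index = self.sys.actuators[1].act_index
    joint_3d_act_index = self.sys.actuators[2].act_index
    # angle_vel returns one array per axis; the velocities are unused here.
    joint_1d_angle, joint_1d_vel = self.sys.joints[0].angle_vel(qp)
    joint_2d_angle, joint_2d_vel = self.sys.joints[1].angle_vel(qp)
    joint_3d_angle, joint_3d_vel = self.sys.joints[2].angle_vel(qp)
    # Stack the per-axis arrays; .T gives (num_joints, dof) so flattening is
    # joint-major and matches act_index.reshape(-1). 1-DoF needs no transpose.
    joint_1d_angle = jnp.stack(joint_1d_angle)
    joint_2d_angle = jnp.stack(joint_2d_angle).T
    joint_3d_angle = jnp.stack(joint_3d_angle).T
    # Scatter each angle into the slot of the actuator that drives it.
    ret = jnp.zeros(self.sys.num_joint_dof)
    ret = ret.at[joint_1d_act_index.reshape(-1)].set(joint_1d_angle.reshape(-1))
    ret = ret.at[joint_2d_act_index.reshape(-1)].set(joint_2d_angle.reshape(-1))
    ret = ret.at[joint_3d_act_index.reshape(-1)].set(joint_3d_angle.reshape(-1))
    return ret
```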
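A PD controller also needs joint velocities, so the same scatter can be applied to the `*_vel` outputs that the workaround discards. Below is a standalone sketch of that idea that takes plain arrays instead of `self.sys` and `qp`. The name `align_to_actuators` is hypothetical, and the sketch keeps the workaround's assumptions: per-axis tuples from `angle_vel`, one row of actuator slots per joint in `act_index`, and `sys.actuators[k]` covering the same joints in the same order as `sys.joints[k]`.

```python
import jax.numpy as jnp


def align_to_actuators(per_type, act_indices, num_dof):
    """Scatter per-joint-type values into one vector in actuator order.

    per_type: list with one tuple per joint type; each tuple holds one array
        per axis, each with one entry per joint (assumed angle_vel layout).
    act_indices: list of int arrays of shape (num_joints, dof), one per type.
    num_dof: total number of joint DoFs (e.g. sys.num_joint_dof).
    """
    out = jnp.zeros(num_dof)
    for axes, idx in zip(per_type, act_indices):
        # (dof, num_joints) -> (num_joints, dof), flattened joint-major so it
        # lines up with idx.reshape(-1).
        vals = jnp.stack(axes).T.reshape(-1)
        out = out.at[idx.reshape(-1)].set(vals)
    return out


# Made-up data: one 1-DoF joint and one 2-DoF joint, 3 DoFs in total.
angles = [
    (jnp.array([0.5]),),                   # 1-DoF joint, axis 0
    (jnp.array([1.0]), jnp.array([2.0])),  # 2-DoF joint, axes 0 and 1
]
vels = [
    (jnp.array([-0.5]),),
    (jnp.array([-1.0]), jnp.array([-2.0])),
]
act_idx = [jnp.array([[2]]), jnp.array([[0, 1]])]

q = align_to_actuators(angles, act_idx, num_dof=3)   # [1.0, 2.0, 0.5]
qd = align_to_actuators(vels, act_idx, num_dof=3)    # [-1.0, -2.0, -0.5]
```

Calling the same function on angles and velocities guarantees that both end up in the same actuator order, which is what the PD law needs.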