warp icon indicating copy to clipboard operation
warp copied to clipboard

[BUG] Using `warp.array` fields in a `warp.struct` may cause a CUDA memory error.

Open · matafela opened this issue 3 months ago · 1 comment

Bug Description

Using `warp.array` fields inside a `warp.struct` may cause a CUDA memory error (Warp version 1.8.0). Code:

import warp as wp
import numpy as np
from typing import Any
from typing import Tuple



@wp.func
def safe_acos(x: float) -> float:
    """Return the arc-cosine of ``x`` after clamping it to [-1, 1].

    Computed cosines (e.g. ``(trace - 1) / 2`` or law-of-cosines ratios) can
    drift slightly outside [-1, 1] through rounding; clamping keeps the
    argument inside the domain of ``wp.acos``.
    """
    bounded = wp.clamp(x, -1.0, 1.0)
    return wp.acos(bounded)


@wp.func
def th4_th6_for_branch(
    i: int,
    r_: wp.mat33f,
    sin1: wp.vec4f,
    cos1: wp.vec4f,
    s23: wp.vec4f,
    c23: wp.vec4f,
) -> Tuple[float, float]:
    """Compute the wrist angles (theta4, theta6) for one IK branch.

    Args:
        i: branch index selecting the entry of each 4-vector (0..3).
        r_: 3x3 rotation block of the target flange pose.
        sin1, cos1: sin/cos of theta1 for each of the four branches.
        s23, c23: sin/cos of (theta2 + theta3) for each branch.

    Returns:
        ``(theta4, theta6)`` as ``atan2`` results for branch ``i``.
    """
    sin_t1 = sin1[i]
    cos_t1 = cos1[i]
    sin_t23 = s23[i]
    cos_t23 = c23[i]

    theta4 = wp.atan2(
        r_[1, 2] * cos_t1 - r_[0, 2] * sin_t1,
        r_[0, 2] * cos_t23 * cos_t1
        + r_[1, 2] * cos_t23 * sin_t1
        - r_[2, 2] * sin_t23,
    )
    theta6 = wp.atan2(
        r_[0, 1] * sin_t23 * cos_t1
        + r_[1, 1] * sin_t23 * sin_t1
        + r_[2, 1] * cos_t23,
        -r_[0, 0] * sin_t23 * cos_t1
        - r_[1, 0] * sin_t23 * sin_t1
        - r_[2, 0] * cos_t23,
    )
    return theta4, theta6


@wp.struct
class OPWparam:
    # Kinematic parameters consumed by opw_single_fk / opw_fk_kernel /
    # opw_ik_kernel (OPW-style arm: presumably ortho-parallel base with a
    # spherical wrist — TODO confirm naming).
    # Geometric lengths; units follow the caller's data (the demo uses values
    # that look like metres — NOTE(review): confirm).
    a1: float
    a2: float
    b: float
    c1: float
    c2: float
    c3: float
    c4: float  # distance used to step from the wrist centre along the tool z-axis
    # Per-joint additive offset, indexed 0..5 by the kernels
    # (opw_fk_kernel applies q * sign_corrections[k] + offsets[k]).
    offsets: wp.array(dtype=float)
    # Per-joint multiplier, indexed 0..5; the demo fills it with -1.0 / +1.0
    # (presumably always +/-1 — TODO confirm).
    sign_corrections: wp.array(dtype=float)


@wp.func
def get_transform_err(
    transform1: wp.mat44f, transform2: wp.mat44f
) -> Tuple[float, float]:
    """Measure how far apart two 4x4 homogeneous transforms are.

    Returns:
        ``(t_err, r_err)`` where ``t_err`` is the Euclidean distance between
        the translation columns and ``r_err`` is the rotation angle (rad) of
        ``R1^T * R2``, recovered as ``acos((trace - 1) / 2)`` via
        ``safe_acos``.
    """
    dx = transform1[0, 3] - transform2[0, 3]
    dy = transform1[1, 3] - transform2[1, 3]
    dz = transform1[2, 3] - transform2[2, 3]
    t_err = wp.length(wp.vec3f(dx, dy, dz))

    # Upper-left 3x3 rotation blocks, row-major.
    rot_a = wp.mat33f(
        transform1[0, 0], transform1[0, 1], transform1[0, 2],
        transform1[1, 0], transform1[1, 1], transform1[1, 2],
        transform1[2, 0], transform1[2, 1], transform1[2, 2],
    )
    rot_b = wp.mat33f(
        transform2[0, 0], transform2[0, 1], transform2[0, 2],
        transform2[1, 0], transform2[1, 1], transform2[1, 2],
        transform2[2, 0], transform2[2, 1], transform2[2, 2],
    )
    relative = wp.transpose(rot_a) * rot_b
    cos_angle = 0.5 * (wp.trace(relative) - 1.0)
    r_err = wp.abs(safe_acos(cos_angle))
    return t_err, r_err


@wp.func
def opw_single_fk(
    q1: float,
    q2: float,
    q3: float,
    q4: float,
    q5: float,
    q6: float,
    params: OPWparam,
):
    """Forward kinematics for one set of model-space joint angles.

    The angles are expected after any sign/offset mapping (opw_fk_kernel
    applies ``q * sign_corrections + offsets`` before calling this).

    Returns:
        ``wp.mat44f`` homogeneous pose: rotation ``R_0c * R_ce`` and
        translation ``wrist_centre + c4 * (third column of the rotation)``.
    """
    # Trigonometric terms of each joint, evaluated once and reused below.
    sn1 = wp.sin(q1)
    cs1 = wp.cos(q1)
    sn2 = wp.sin(q2)
    cs2 = wp.cos(q2)
    sn3 = wp.sin(q3)
    cs3 = wp.cos(q3)
    sn4 = wp.sin(q4)
    cs4 = wp.cos(q4)
    sn5 = wp.sin(q5)
    cs5 = wp.cos(q5)
    sn6 = wp.sin(q6)
    cs6 = wp.cos(q6)

    # Wrist-centre position: first in the frame rotated by q1, then in base.
    psi3 = wp.atan2(params.a2, params.c3)
    k = wp.sqrt(params.a2 * params.a2 + params.c3 * params.c3)
    phi = q2 + q3 + psi3
    arm_x = params.c2 * sn2 + k * wp.sin(phi) + params.a1
    arm_y = params.b
    arm_z = params.c2 * cs2 + k * wp.cos(phi)

    wrist_x = arm_x * cs1 - arm_y * sn1
    wrist_y = arm_x * sn1 + arm_y * cs1
    wrist_z = arm_z + params.c1

    # Rotation of the first three joints (base -> wrist centre frame).
    rot_arm = wp.mat33f(
        cs1 * cs2 * cs3 - cs1 * sn2 * sn3,
        -sn1,
        cs1 * cs2 * sn3 + cs1 * sn2 * cs3,
        sn1 * cs2 * cs3 - sn1 * sn2 * sn3,
        cs1,
        sn1 * cs2 * sn3 + sn1 * sn2 * cs3,
        -sn2 * cs3 - cs2 * sn3,
        0.0,
        -sn2 * sn3 + cs2 * cs3,
    )
    # Rotation of the wrist joints 4..6.
    rot_wrist = wp.mat33f(
        cs4 * cs5 * cs6 - sn4 * sn6,
        -cs4 * cs5 * sn6 - sn4 * cs6,
        cs4 * sn5,
        sn4 * cs5 * cs6 + cs4 * sn6,
        -sn4 * cs5 * sn6 + cs4 * cs6,
        sn4 * sn5,
        -sn5 * cs6,
        sn5 * sn6,
        cs5,
    )
    rot = rot_arm * rot_wrist

    # Step from the wrist centre by c4 along the third column of `rot`.
    px = wrist_x + params.c4 * rot[0, 2]
    py = wrist_y + params.c4 * rot[1, 2]
    pz = wrist_z + params.c4 * rot[2, 2]

    return wp.mat44f(
        rot[0, 0], rot[0, 1], rot[0, 2], px,
        rot[1, 0], rot[1, 1], rot[1, 2], py,
        rot[2, 0], rot[2, 1], rot[2, 2], pz,
        0.0, 0.0, 0.0, 1.0,
    )


@wp.kernel
def opw_fk_kernel(
    qpos: wp.array(dtype=float),
    ee_pose: wp.mat44f,
    params: OPWparam,
    xpos: wp.array(dtype=float),
):
    """Batched forward kinematics, one thread per 6-DOF joint vector.

    Thread ``tid`` reads ``qpos[tid*6 : tid*6 + 6]`` and maps each joint
    through ``q * sign_corrections[k] + offsets[k]``. It then evaluates
    ``opw_single_fk`` and right-multiplies by ``ee_pose``. The resulting 4x4
    matrix is stored row-major in ``xpos[tid*16 : tid*16 + 16]``.
    """
    tid = wp.tid()
    dof = 6
    base = tid * dof

    q1 = qpos[base + 0] * params.sign_corrections[0] + params.offsets[0]
    q2 = qpos[base + 1] * params.sign_corrections[1] + params.offsets[1]
    q3 = qpos[base + 2] * params.sign_corrections[2] + params.offsets[2]
    q4 = qpos[base + 3] * params.sign_corrections[3] + params.offsets[3]
    q5 = qpos[base + 4] * params.sign_corrections[4] + params.offsets[4]
    q6 = qpos[base + 5] * params.sign_corrections[5] + params.offsets[5]

    pose = opw_single_fk(q1, q2, q3, q4, q5, q6, params) * ee_pose

    # Row-major flatten of the 4x4 result.
    out_base = tid * 16
    for row in range(4):
        for col in range(4):
            xpos[out_base + row * 4 + col] = pose[row, col]


# 48-float vector type; opw_ik_kernel packs its 8 candidate solutions of
# 6 joint angles each into one value of this type.
wp_vec48f = wp.types.vector(length=48, dtype=float)

@wp.kernel
def opw_ik_kernel(
    xpos: wp.array(dtype=float),
    ee_pose_inv: wp.mat44f,
    params: OPWparam,
    qpos: wp.array(dtype=float),
    ik_valid: wp.array(dtype=bool)
):
    """Closed-form OPW inverse kinematics, one thread per target pose.

    Thread ``i`` reads the 4x4 pose ``xpos[i*16 : i*16 + 16]`` (row-major),
    right-multiplies it by ``ee_pose_inv`` and builds 8 candidate solutions:
    2 choices of theta1, times 2 theta2/theta3 pairs each, times a wrist flip
    (theta4 + pi, -theta5, theta6 - pi).

    Outputs:
        qpos: joint-space solutions; solution ``j``, joint ``k`` is stored at
            ``i*48 + j*6 + k``. Model-space angles are mapped back through the
            inverse of the ``q * sign + offset`` mapping used by
            ``opw_fk_kernel``, i.e. ``(theta - offset) * sign`` (exact for
            sign in {-1, +1}).
        ik_valid: ``ik_valid[i*8 + j]`` is True only when the forward
            kinematics of solution ``j`` matches the target within 1e-3
            (translation) and 1e-2 rad (rotation). NaN errors count as invalid.
    """
    i = wp.tid()
    # TODO: warp slice ?
    ee_pose = (
        wp.mat44f(
            xpos[i * 16 + 0],
            xpos[i * 16 + 1],
            xpos[i * 16 + 2],
            xpos[i * 16 + 3],
            xpos[i * 16 + 4],
            xpos[i * 16 + 5],
            xpos[i * 16 + 6],
            xpos[i * 16 + 7],
            xpos[i * 16 + 8],
            xpos[i * 16 + 9],
            xpos[i * 16 + 10],
            xpos[i * 16 + 11],
            xpos[i * 16 + 12],
            xpos[i * 16 + 13],
            xpos[i * 16 + 14],
            xpos[i * 16 + 15],
        )
        * ee_pose_inv
    )
    r_ = wp.mat33f(
        ee_pose[0, 0],
        ee_pose[0, 1],
        ee_pose[0, 2],
        ee_pose[1, 0],
        ee_pose[1, 1],
        ee_pose[1, 2],
        ee_pose[2, 0],
        ee_pose[2, 1],
        ee_pose[2, 2],
    )
    rz_ = wp.vec3f(ee_pose[0, 2], ee_pose[1, 2], ee_pose[2, 2])
    t_ = wp.vec3f(ee_pose[0, 3], ee_pose[1, 3], ee_pose[2, 3])

    # to wrist center position
    c = t_ - params.c4 * rz_

    r_xy2 = c[0] * c[0] + c[1] * c[1]
    # NOTE(review): no guard if this is negative; wp.sqrt would then yield a
    # non-finite value, which the validity check below reports as invalid.
    nx1_sqrt_arg = r_xy2 - params.b * params.b
    nx1 = wp.sqrt(nx1_sqrt_arg) - params.a1

    # theta1: two candidates
    tmp1 = wp.atan2(c[1], c[0])
    tmp2 = wp.atan2(params.b, nx1 + params.a1)
    theta1_i = tmp1 - tmp2
    theta1_ii = tmp1 + tmp2 - wp.pi

    tmp3 = c[2] - params.c1
    s1_2 = nx1 * nx1 + tmp3 * tmp3

    tmp4 = nx1 + 2.0 * params.a1
    s2_2 = tmp4 * tmp4 + tmp3 * tmp3
    kappa_2 = params.a2 * params.a2 + params.c3 * params.c3

    c2_2 = params.c2 * params.c2

    tmp5 = s1_2 + c2_2 - kappa_2
    s1 = wp.sqrt(s1_2)
    s2 = wp.sqrt(s2_2)

    # theta2
    tmp13 = safe_acos(tmp5 / (2.0 * s1 * params.c2))
    tmp14 = wp.atan2(nx1, c[2] - params.c1)
    theta2_i = -tmp13 + tmp14
    theta2_ii = tmp13 + tmp14

    tmp6 = s2_2 + c2_2 - kappa_2
    tmp15 = safe_acos(tmp6 / (2.0 * s2 * params.c2))
    tmp16 = wp.atan2(nx1 + 2.0 * params.a1, c[2] - params.c1)
    theta2_iii = -tmp15 - tmp16
    theta2_iv = tmp15 - tmp16

    # theta3
    tmp7 = s1_2 - c2_2 - kappa_2
    tmp8 = s2_2 - c2_2 - kappa_2
    tmp9 = 2.0 * params.c2 * wp.sqrt(kappa_2)
    tmp10 = wp.atan2(params.a2, params.c3)

    tmp11 = safe_acos(tmp7 / tmp9)
    theta3_i = tmp11 - tmp10
    theta3_ii = -tmp11 - tmp10

    tmp12 = safe_acos(tmp8 / tmp9)
    theta3_iii = tmp12 - tmp10
    theta3_iv = -tmp12 - tmp10

    # precompute sin/cos(theta1)
    theta1_i_sin = wp.sin(theta1_i)
    theta1_i_cos = wp.cos(theta1_i)
    theta1_ii_sin = wp.sin(theta1_ii)
    theta1_ii_cos = wp.cos(theta1_ii)

    # Branches 0,1 use theta1_i; branches 2,3 use theta1_ii.
    sin1 = wp.vec4f(theta1_i_sin, theta1_i_sin, theta1_ii_sin, theta1_ii_sin)
    cos1 = wp.vec4f(theta1_i_cos, theta1_i_cos, theta1_ii_cos, theta1_ii_cos)
    s23 = wp.vec4f(
        wp.sin(theta2_i + theta3_i),
        wp.sin(theta2_ii + theta3_ii),
        wp.sin(theta2_iii + theta3_iii),
        wp.sin(theta2_iv + theta3_iv),
    )
    c23 = wp.vec4f(
        wp.cos(theta2_i + theta3_i),
        wp.cos(theta2_ii + theta3_ii),
        wp.cos(theta2_iii + theta3_iii),
        wp.cos(theta2_iv + theta3_iv),
    )

    # m for theta5
    m = wp.vec4f(
        r_[0, 2] * s23[0] * cos1[0] + r_[1, 2] * s23[0] * sin1[0] + r_[2, 2] * c23[0],
        r_[0, 2] * s23[1] * cos1[1] + r_[1, 2] * s23[1] * sin1[1] + r_[2, 2] * c23[1],
        r_[0, 2] * s23[2] * cos1[2] + r_[1, 2] * s23[2] * sin1[2] + r_[2, 2] * c23[2],
        r_[0, 2] * s23[3] * cos1[3] + r_[1, 2] * s23[3] * sin1[3] + r_[2, 2] * c23[3],
    )
    theta5 = wp.vec4f(
        wp.atan2(wp.sqrt(wp.clamp(1.0 - m[0] * m[0], 0.0, 1.0)), m[0]),
        wp.atan2(wp.sqrt(wp.clamp(1.0 - m[1] * m[1], 0.0, 1.0)), m[1]),
        wp.atan2(wp.sqrt(wp.clamp(1.0 - m[2] * m[2], 0.0, 1.0)), m[2]),
        wp.atan2(wp.sqrt(wp.clamp(1.0 - m[3] * m[3], 0.0, 1.0)), m[3]),
    )

    theta4_i, theta6_i = th4_th6_for_branch(0, r_, sin1, cos1, s23, c23)
    theta4_ii, theta6_ii = th4_th6_for_branch(1, r_, sin1, cos1, s23, c23)
    theta4_iii, theta6_iii = th4_th6_for_branch(2, r_, sin1, cos1, s23, c23)
    theta4_iv, theta6_iv = th4_th6_for_branch(3, r_, sin1, cos1, s23, c23)
    theta5_i, theta5_ii, theta5_iii, theta5_iv = (
        theta5[0],
        theta5[1],
        theta5[2],
        theta5[3],
    )
    # Wrist-flipped counterparts of the four branches above.
    theta5_v, theta5_vi, theta5_vii, theta5_viii = (
        -theta5_i,
        -theta5_ii,
        -theta5_iii,
        -theta5_iv,
    )

    theta4_v, theta4_vi, theta4_vii, theta4_viii = (
        theta4_i + wp.pi,
        theta4_ii + wp.pi,
        theta4_iii + wp.pi,
        theta4_iv + wp.pi,
    )
    theta6_v, theta6_vi, theta6_vii, theta6_viii = (
        theta6_i - wp.pi,
        theta6_ii - wp.pi,
        theta6_iii - wp.pi,
        theta6_iv - wp.pi,
    )
    # combine all 8 solutions (model space), solution j at theta[j*6 : j*6+6]
    theta = wp_vec48f(
        theta1_i,
        theta2_i,
        theta3_i,
        theta4_i,
        theta5_i,
        theta6_i,
        theta1_i,
        theta2_ii,
        theta3_ii,
        theta4_ii,
        theta5_ii,
        theta6_ii,
        theta1_ii,
        theta2_iii,
        theta3_iii,
        theta4_iii,
        theta5_iii,
        theta6_iii,
        theta1_ii,
        theta2_iv,
        theta3_iv,
        theta4_iv,
        theta5_iv,
        theta6_iv,
        theta1_i,
        theta2_i,
        theta3_i,
        theta4_v,
        theta5_v,
        theta6_v,
        theta1_i,
        theta2_ii,
        theta3_ii,
        theta4_vi,
        theta5_vi,
        theta6_vi,
        theta1_ii,
        theta2_iii,
        theta3_iii,
        theta4_vii,
        theta5_vii,
        theta6_vii,
        theta1_ii,
        theta2_iv,
        theta3_iv,
        theta4_viii,
        theta5_viii,
        theta6_viii,
    )
    DOF = 6
    N_SOL = 8
    for j in range(N_SOL):
        sol_start = j * DOF
        qpos_start = i * DOF * N_SOL + sol_start

        # Invert opw_fk_kernel's mapping q_model = q * sign + offset:
        # q = (q_model - offset) * sign (exact because sign * sign == 1).
        # The previous `theta * sign - offset` was only correct when
        # sign == +1 or offset == 0.
        for k in range(DOF):
            qpos[qpos_start + k] = (
                theta[sol_start + k] - params.offsets[k]
            ) * params.sign_corrections[k]

        # Validate in model space: opw_single_fk expects the same angles that
        # opw_fk_kernel passes it (after sign/offset mapping), i.e. the raw
        # thetas, not the joint-space values just written to qpos.
        check_ee_pose = opw_single_fk(
            theta[sol_start + 0],
            theta[sol_start + 1],
            theta[sol_start + 2],
            theta[sol_start + 3],
            theta[sol_start + 4],
            theta[sol_start + 5],
            params,
        )
        t_err, r_err = get_transform_err(check_ee_pose, ee_pose)
        # Written as "within tolerance" so NaN errors are reported invalid
        # (every comparison with NaN is False).
        if t_err <= 1e-3 and r_err <= 1e-2:
            ik_valid[i * N_SOL + j] = True
        else:
            ik_valid[i * N_SOL + j] = False


# TODO: remove this before pr merged
if __name__ == "__main__":
    # Demo / reproduction: run batched FK on a few joint configurations, then
    # feed the resulting poses back through IK and print the 8 solutions per
    # sample together with their validity flags.
    device = "cuda"
    # init params
    params = OPWparam()
    params.a1 = 0.400333
    params.a2 = -0.251449
    params.b = 0.0
    params.c1 = 0.8300
    params.c2 = 1.177556
    params.c3 = 1.443593
    params.c4 = 0.2300
    # Struct array fields live on `device`; every launch below must use the
    # same device.
    params.offsets = wp.array(np.array([0.0, 0, 0, 0, 0, 0]), dtype=float, device=device)
    flip_axes = (True, False, True, True, False, True)
    sign_corrections = [-1.0 if flip else 1.0 for flip in flip_axes]
    params.sign_corrections = wp.array(np.array(sign_corrections), dtype=float, device=device)

    test_qposes = np.array(
        [
            [0, 0, -np.pi / 2, 0, 0, 0],
            [np.pi / 18, 0, -np.pi / 2, 0, 0, 0],
            [0, np.pi / 18, -np.pi / 2, 0, 0, 0],
            [0, 0, -4 * np.pi / 9, 0, 0, 0],
            [0, 0, -np.pi / 2, np.pi / 18, 0, 0],
            [0, 0, -np.pi / 2, 0, np.pi / 18, 0],
            [0, 0, -np.pi / 2, 0, 0, np.pi / 18],
            [np.pi / 18, np.pi / 9, -np.pi / 2, np.pi / 6, np.pi / 9, np.pi / 18],
        ]
    )
    # test_qposes = np.random.uniform(size=(10, 6), low=-np.pi, high=np.pi)
    n_sample = test_qposes.shape[0]

    ee_pose = np.array(
        [
            [0, 0, -1, 0],
            [0, 1, 0, 0],
            [1, 0, 0, 0],
            [0, 0, 0, 1],
        ],
        dtype=float,
    )
    # Closed-form inverse of a rigid transform: [R^T, -R^T t; 0, 1].
    ee_pose_inv = np.eye(4, dtype=float)
    ee_pose_inv[:3, :3] = ee_pose[:3, :3].T
    ee_pose_inv[:3, 3] = -ee_pose[:3, :3].T @ ee_pose[:3, 3]

    qpos_wp = wp.array(test_qposes.flatten(), dtype=float, device=device)
    xpos_wp = wp.zeros(n_sample * 16, dtype=float, device=device)
    ee_pose_wp = wp.mat44f(ee_pose)
    wp.launch(
        kernel=opw_fk_kernel,
        dim=n_sample,
        inputs=[
            qpos_wp,
            ee_pose_wp,
            params,
        ],
        outputs=[
            xpos_wp,
        ],
        device=device,
    )

    # Copy of the FK output (with a gradient buffer allocated) used as IK input.
    xpos_wp2 = wp.array(xpos_wp, requires_grad=True, device=device)
    solution_qpos_wp = wp.zeros(n_sample * 8 * 6, dtype=float, device=device)
    solution_valid_wp = wp.zeros(n_sample * 8, dtype=bool, device=device)
    ee_pose_inv_wp = wp.mat44f(ee_pose_inv)
    wp.launch(
        kernel=opw_ik_kernel,
        dim=n_sample,
        inputs=[
            xpos_wp2,
            ee_pose_inv_wp,
            params,
        ],
        outputs=[
            solution_qpos_wp,
            solution_valid_wp
        ],
        device=device,
    )
    solution_qpos = solution_qpos_wp.numpy().reshape(n_sample, 8, 6)
    solution_valid = solution_valid_wp.numpy().reshape(n_sample, 8)
    print("solution_qpos:", solution_qpos)
    print("solution_valid:", solution_valid)

When I remove the `offsets` and `sign_corrections` fields from `OPWparam`, the crash no longer occurs.

System Information

No response

matafela avatar Oct 17 '25 10:10 matafela

Thanks for sharing this issue @matafela. We're still pondering it.

shi-eric avatar Oct 20 '25 22:10 shi-eric