grasp_diffusion

Minor: `n_grasps`, `n_envs` and `batch` usage in sampling scripts

Open · kuldeepbrd1 opened this issue 1 year ago · 1 comment

The scripts in `scripts/sample` do not use `n_grasps` from the CLI arguments.

`n_grasps` and `n_envs` are not used. Instead, the scripts always generate a number of grasps equal to the `batch` size specified in `get_approximated_grasp_diffusion_field(...)`, as here: https://github.com/TheCamusean/grasp_diffusion/blob/3a2cb1448270798435479ee7cf8d1fbd9d5127c5/scripts/sample/generate_partial_pointcloud_6d_grasp_poses.py#L28

Sampling `n_grasps` in batches would be nice. It would also avoid CUDA/CPU out-of-memory errors when `n_grasps` is large. Something like this in `main(...)`:

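```python
# Assumes numpy (as np) and torch are already imported at module level.
if __name__ == "__main__":
    ...
    n_grasps = int(args.n_grasps)
    obj_id = int(args.obj_id)
    obj_class = args.obj_class

    # Number of grasps produced by each call to generator.sample()
    batch_size = 10

    ## Set Model and Sample Generator ##
    P, mesh = sample_pointcloud(obj_id, obj_class)
    generator, model = get_approximated_grasp_diffusion_field(
        P, args, batch=batch_size, device=device
    )

    # Draw ceil(n_grasps / batch_size) batches, then drop the surplus
    # from the last batch so exactly n_grasps poses are kept.
    H_batches = []
    n_batches = int(np.ceil(n_grasps / batch_size))
    for _ in range(n_batches):
        H_batches.append(generator.sample())

    H = torch.cat(H_batches, dim=0)[:n_grasps]
    H[..., :3, -1] *= 1 / 8.0
    ...
```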
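
The batching loop in `main(...)` can also be factored into a small helper. This is a minimal sketch, not code from the repository: `sample_in_batches` and `dummy_sampler` are hypothetical names, and it assumes `generator.sample()` takes no arguments and returns a tensor of shape `(batch_size, 4, 4)` (homogeneous grasp poses, as the `H[..., :3, -1]` indexing suggests). Each batch is moved to the CPU so the accumulated results do not stay in GPU memory.

```python
import math
from typing import Callable

import torch


def sample_in_batches(
    sample_fn: Callable[[], torch.Tensor], n_grasps: int, batch_size: int
) -> torch.Tensor:
    """Call a fixed-batch sampler repeatedly until n_grasps poses are collected.

    sample_fn is assumed to return a tensor of shape (batch_size, 4, 4).
    """
    if n_grasps <= 0 or batch_size <= 0:
        raise ValueError("n_grasps and batch_size must be positive")
    n_batches = math.ceil(n_grasps / batch_size)
    batches = []
    for _ in range(n_batches):
        # Detach and move to CPU so finished batches do not hold GPU memory.
        batches.append(sample_fn().detach().cpu())
    H = torch.cat(batches, dim=0)
    # The last batch may overshoot, so trim to exactly n_grasps.
    return H[:n_grasps]


# Hypothetical stand-in for generator.sample(): returns identity poses.
batch_size = 10
dummy_sampler = lambda: torch.eye(4).repeat(batch_size, 1, 1)
H = sample_in_batches(dummy_sampler, n_grasps=25, batch_size=batch_size)
print(H.shape)  # torch.Size([25, 4, 4])
```

With 25 requested grasps and a batch size of 10, three batches (30 poses) are drawn and the last 5 are discarded.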

and `get_approximated_grasp_diffusion_field(...)` changed to take the batch size as a parameter:

```python
def get_approximated_grasp_diffusion_field(p, args, batch=10, device="cpu"):
    model_params = args.model

    ## Load model
    model_args = {"device": device, "pretrained_model": model_params}
    model = load_model(model_args)
    ...  # rest unchanged, except it uses the `batch` argument instead of a fixed value
```
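
Since the issue is that the CLI `n_grasps` is ignored, the batch size could likewise be exposed on the command line. The sketch below is only an illustration: `--batch_size` is a hypothetical flag that the original scripts do not define, `parse_args` is a hypothetical helper, and the default values are placeholders.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Batched grasp sampling (sketch)")
    # Placeholder default; the real scripts define their own.
    parser.add_argument("--n_grasps", type=int, default=200)
    # Hypothetical flag, not present in the original scripts.
    parser.add_argument("--batch_size", type=int, default=10)
    args = parser.parse_args(argv)
    if args.n_grasps < 1 or args.batch_size < 1:
        parser.error("--n_grasps and --batch_size must be positive")
    # Never sample a bigger batch than the total number of grasps requested.
    args.batch_size = min(args.batch_size, args.n_grasps)
    return args


args = parse_args(["--n_grasps", "25", "--batch_size", "10"])
print(args.n_grasps, args.batch_size)  # 25 10
```

The parsed value would then be passed through as `batch=args.batch_size` when calling `get_approximated_grasp_diffusion_field(...)`. Clamping the batch size to `n_grasps` avoids sampling and discarding extra poses when only a few grasps are requested.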

It's not critical to add this to the code, so I'm just highlighting it here. (Happy to create a pull request if you'd like.)

kuldeepbrd1 · Oct 29 '22 14:10