grasp_diffusion
Minor: `n_grasps`, `n_envs` and `batch` usage in sampling scripts
The scripts in `scripts/sample` do not use `n_grasps` from the CLI arguments: neither `n_grasps` nor `n_envs` is used. Instead, they always generate a number of grasps equal to the batch size specified inside `get_approximated_grasp_diffusion_field(...)`, as here:
https://github.com/TheCamusean/grasp_diffusion/blob/3a2cb1448270798435479ee7cf8d1fbd9d5127c5/scripts/sample/generate_partial_pointcloud_6d_grasp_poses.py#L28
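For illustration, here is a paraphrased sketch of the current behaviour (not verbatim from the repository): the batch size is fixed inside `get_approximated_grasp_diffusion_field(...)`, so `generator.sample()` always returns that many grasps, regardless of `args.n_grasps`:

```python
# Paraphrased sketch of the current flow; the concrete internal batch size
# (10) is an assumption based on the hard-coded default in the helper.
P, mesh = sample_pointcloud(obj_id, obj_class)
generator, model = get_approximated_grasp_diffusion_field(P, args, device=device)
H = generator.sample()   # H has a fixed batch dimension set inside the helper
print(H.shape[0])        # == internal batch size (e.g. 10), not args.n_grasps
```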
Batch-wise sampling of `n_grasps` would be nice. It would also avoid CUDA/CPU out-of-memory errors when `n_grasps` is large. Something like this in `main(...)`:
```python
if __name__ == "__main__":
    ...
    n_grasps = int(args.n_grasps)
    obj_id = int(args.obj_id)
    obj_class = args.obj_class
    batch_size = 10

    ## Set Model and Sample Generator ##
    P, mesh = sample_pointcloud(obj_id, obj_class)
    generator, model = get_approximated_grasp_diffusion_field(
        P, args, batch=batch_size, device=device
    )

    ## Sample in batches of batch_size until at least n_grasps are collected ##
    H_batches = []
    batches = int(np.ceil(n_grasps / batch_size))
    for _ in range(batches):
        H_batches.append(generator.sample())
    H = torch.cat(H_batches, dim=0)[:n_grasps]  # keep exactly n_grasps poses
    H[..., :3, -1] *= 1 / 8.0  # rescale grasp translations (scripts use a 1/8 factor)
    ...
```
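More generally, the batching loop could live in a small reusable helper (a minimal sketch; `sample_in_batches` is a hypothetical name, and it assumes `generator.sample()` returns a tensor whose leading dimension is the batch, as in the scripts). Moving each batch to CPU keeps GPU memory bounded by a single batch, which is what actually avoids CUDA out-of-memory errors for large `n_grasps`:

```python
import numpy as np
import torch


def sample_in_batches(generator, n_grasps: int, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: call generator.sample() repeatedly and concatenate.

    Each batch is detached and moved to CPU so GPU memory stays bounded by a
    single batch, regardless of how large n_grasps is.
    """
    n_batches = int(np.ceil(n_grasps / batch_size))
    H_batches = [generator.sample().detach().cpu() for _ in range(n_batches)]
    return torch.cat(H_batches, dim=0)[:n_grasps]
```

With that in place, the loop in `main(...)` above reduces to `H = sample_in_batches(generator, n_grasps, batch_size)`, moved back to `device` afterwards if the downstream code expects it there.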
and `get_approximated_grasp_diffusion_field(...)` changed to:
```python
def get_approximated_grasp_diffusion_field(p, args, batch=10, device="cpu"):
    model_params = args.model
    ## Load model
    model_args = {"device": device, "pretrained_model": model_params}
    model = load_model(model_args)
    # ... rest of the function unchanged, except that the new `batch` parameter
    # replaces the previously hard-coded batch size (see the sketch below)
```
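Inside the function body, the new `batch` argument would then replace the hard-coded value wherever it is used. A rough sketch of the remainder (the calls below are recalled approximately from the script rather than quoted, so treat the exact names and keyword arguments as assumptions):

```python
    # Rough sketch -- forward `batch` instead of the hard-coded constant.
    # Exact call names/keyword arguments are approximate, not quoted verbatim.
    context = to_torch(p[None, ...], device)
    model.set_latent(context, batch=batch)
    generator = Grasp_AnnealedLD(model, batch=batch, device=device)  # plus the script's other sampler kwargs
    return generator, model
```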
It's not super critical to add this to the code, so I'm just highlighting it here. (Also happy to create a pull request, if you'd like.)
Thanks @kuldeepbrd1.
You are right! This would be highly beneficial for sampling the grasps properly. If you feel up to it, create a pull request; I'll test it, and if everything works smoothly, I'll accept it.
Thanks a lot :)