bilevel_coresets
How to speed up the loop in build_with_representer_proxy_batch
Hi, thank you very much for providing the code! I've installed JAX with CUDA, so evaluating kernel_fn is now faster. However, build_with_representer_proxy_batch is still quite slow. I assume this is due to solve_bilevel_opt_representer_proxy, which requires computing the implicit gradient. Or is something else the bottleneck? Is there a way to make the computation faster? Thank you!
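To narrow down which stage dominates, one option is to time each part of the loop separately. Below is a minimal sketch using only the Python standard library. The functions kernel_stage and bilevel_stage are hypothetical stand-ins, not functions from the repository; in practice they would be replaced by the kernel evaluation and the call to solve_bilevel_opt_representer_proxy.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# Accumulated wall-clock time per labelled stage.
timings = defaultdict(float)

@contextmanager
def timed(label):
    """Add the elapsed time of the enclosed block to timings[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] += time.perf_counter() - start

# Hypothetical stand-ins for the two stages of the coreset loop.
def kernel_stage(n):
    return sum(i * i for i in range(n))

def bilevel_stage(n):
    return sum(i for i in range(n))

# Time each stage over a few iterations of a mock outer loop.
for _ in range(3):
    with timed("kernel"):
        kernel_stage(10_000)
    with timed("bilevel_solve"):
        bilevel_stage(10_000)

# Label of the stage with the largest accumulated time.
slowest = max(timings, key=timings.get)
print(dict(timings), "slowest:", slowest)
```

Note that JAX dispatches work to the GPU asynchronously, so when timing JAX code the result should be forced with block_until_ready() inside the timed block; otherwise the measurement may only capture the dispatch time.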