
[Bug]: Some of the compute_analyses cards fail when using models with CUDA

Open VMLC-PV opened this issue 7 months ago • 7 comments

What happened?

The client.compute_analyses does not output all cards if the model runs on CUDA. If I use the Modular BoTorch interface and pass "torch_device": torch.device("cuda" if torch.cuda.is_available() else "cpu") in the model_kwargs, the optimization runs just fine on the GPU, but compute_analyses fails for some of the cards. I tracked down the issue to sobol_measures.py and derivative_measures.py, which generate tensors in several places and pass them to the model without checking which device the model is running on. Note that if torch_device is the CPU, everything works just fine.
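
For illustration, the failure is the usual PyTorch device mismatch. Here is a minimal sketch with a toy BoTorch model (not actual Ax code, just to show the pattern):

import torch
from botorch.models import SingleTaskGP

# Toy GP on random data, placed on the GPU when available (mirrors what happens
# when torch_device=cuda is passed via model_kwargs to the Modular BoTorch model).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_X = torch.rand(10, 3, dtype=torch.double, device=device)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

x = torch.rand(8, 3, dtype=torch.double)  # new tensors are created on the CPU by default
# model.posterior(x)  # on a CUDA machine this fails with a device-mismatch RuntimeError

# moving the inputs to the model's device first avoids the error
x = x.to(next(model.parameters()).device)
posterior = model.posterior(x)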

I could hack my way around this by first checking the model's device and then sending the tensors to that device, but that might not be the cleanest way to do it.

Any tips on how to proceed cleanly?

Please provide a minimal, reproducible example of the unexpected behavior.

The following lines show how I solved this issue for now. Note that several parts of the code need to be updated in the same way to get it to run.

def input_function(x: Tensor) -> Tensor:
    with torch.no_grad():
        means, variances = [], []
        # Since we're only looking at mean & variance, we can freely
        # use mini-batches.
        x = x.to(next(self.model.parameters()).device)  # get x to the same device as the model <---- NEW
        for x_split in x.split(split_size=mini_batch_size):                    
            p = assert_is_instance(
                self.model.posterior(x_split),
                GPyTorchPosterior,
            )
            means.append(p.mean)
            variances.append(p.variance)

        cat_dim = 1 if is_ensemble(self.model) else 0
        return link_function(
            torch.cat(means, dim=cat_dim), torch.cat(variances, dim=cat_dim)
        )

Please paste any relevant traceback/logs produced by the example provided.


Ax Version

1.0.0

Python Version

3.13.2

Operating System

Ubuntu

(Optional) Describe any potential fixes you've considered to the issue outlined above.

No response

Pull Request

None

Code of Conduct

  • [x] I agree to follow Ax's Code of Conduct

VMLC-PV avatar May 22 '25 12:05 VMLC-PV

Thanks for flagging this. This is something we should fix on our end: the sensitivity analysis should work regardless of which device the input_function operates on, even in the SobolSensitivity base class.

Right now it seems that we don't have enough information in the interface to know that, though, unless we adopt the convention that the device that bounds lives on is the device that input_function expects its tensors on. Or we provide the device explicitly as an argument to the SobolSensitivity constructor.
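
For concreteness, a rough sketch of what those two options could look like (class and argument names here are illustrative only, not Ax's actual SobolSensitivity API):

import torch

class SobolSensitivitySketch:
    """Illustrative only -- not the real SobolSensitivity signature."""

    def __init__(
        self,
        bounds: torch.Tensor,  # assumed 2 x d, BoTorch-style
        num_mc_samples: int = 1024,
        device: torch.device | None = None,
    ) -> None:
        # Option 2: take the device explicitly; otherwise fall back to the
        # Option 1 convention that input_function expects tensors on bounds.device.
        self.device = device if device is not None else bounds.device
        self.bounds = bounds.to(self.device)
        self.num_mc_samples = num_mc_samples

    def draw_base_samples(self) -> torch.Tensor:
        # SobolEngine draws on the CPU; move the samples to the target device
        # and scale them into the bounds before handing them to input_function.
        dim = self.bounds.shape[-1]
        unit = (
            torch.quasirandom.SobolEngine(dimension=dim, scramble=True)
            .draw(self.num_mc_samples)
            .to(self.device, dtype=self.bounds.dtype)
        )
        lower, upper = self.bounds[0], self.bounds[1]
        return lower + (upper - lower) * unit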

In the meantime, the change you made to SobolSensitivityGPMean seems like a good short-term fix; would you mind putting up a PR?

Balandat avatar May 22 '25 13:05 Balandat

I'll try to do so at the end of next week!

VMLC-PV avatar May 22 '25 14:05 VMLC-PV

Hi folks, what's the latest on this issue? Just trying to understand if it's still active. @VMLC-PV , @Balandat

lena-kashtelyan avatar May 28 '25 20:05 lena-kashtelyan

It is, I just haven't had the time yet to put up a PR.

VMLC-PV avatar May 28 '25 20:05 VMLC-PV

Perfect, got it, I'm going to keep this around then! Appreciate the quick response :)

lena-kashtelyan avatar May 28 '25 21:05 lena-kashtelyan

OK, I have given it a try: #3833. This is the first time I've pushed a PR to a repo that isn't mine, so hopefully I didn't mess anything up.

I ran pytest on my machine, and a few tests failed, but they seem unrelated to what I added. The failing tests are TestPyTorchCNNTorchvision::test_deterministic, TestBenchmark::test_replication_mbm, and a few in TestAxClient. All of them also appear to involve a mismatch between the tensors' devices, and I could not fix them. I am unsure whether this originates from my machine or the tests themselves.

VMLC-PV avatar May 29 '25 15:05 VMLC-PV

Alright, hopefully #3839 is fine. I reran the tests locally, and the results are the same as before.

VMLC-PV avatar May 30 '25 07:05 VMLC-PV