Max Balandat

476 comments by Max Balandat

I looked into this and hacked my way around setting a `ModelList` (basically just had this function return `True` locally): https://github.com/facebook/Ax/blob/ff1445aa089f47afd86079057689970d19d85cd6/ax/models/torch/botorch_modular/utils.py#L45 However, the next problem (and that one seems more serious)...

> The initialization of the `BaseTestProblem` class should not enforce double precision in the buffer for the bounds but probably use `torch.get_default_dtype()`

Yep, that makes sense. If `self._bounds` are Python...
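A minimal sketch of the suggested change, using a hypothetical `ToyTestProblem` module rather than BoTorch's actual `BaseTestProblem`: the bounds are kept as plain Python floats and registered as a buffer in `torch.get_default_dtype()` instead of a hard-coded `torch.double`, so they follow whatever default the user has set.

```python
import torch
from torch import nn


class ToyTestProblem(nn.Module):
    """Hypothetical stand-in for a test problem; not BoTorch's BaseTestProblem."""

    # Bounds stored as plain Python floats, one (lower, upper) pair per dimension.
    _bounds = [(-5.0, 10.0), (0.0, 15.0)]

    def __init__(self) -> None:
        super().__init__()
        # Use the global default dtype rather than enforcing torch.double.
        # Transpose to a 2 x d tensor: row 0 holds lower bounds, row 1 upper bounds.
        bounds = torch.tensor(self._bounds, dtype=torch.get_default_dtype())
        self.register_buffer("bounds", bounds.transpose(-1, -2))
```

Because `bounds` is a registered buffer, an explicit `problem.to(torch.float64)` still converts it, so users who want double precision can opt in without the class forcing it.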

Great. BTW, have you been using BoTorch more generally with MPS on Mac? I have tried this in the past but a lot of operations that we (and GPyTorch) use...

> It seems like there are several places where there is an implicit assumption on using double precision/not the default one if something is enforced. The main reason this implicit...

Actually, regarding the issue about `set_tensors_from_ndarray_1d`: we did change this in the past, IIRC in response to some flaky tests: https://github.com/pytorch/botorch/pull/1508/files. That doesn't mean the current setup is the...
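For context, a helper of this kind writes consecutive slices of a flat NumPy array (e.g. the parameter vector handed back by `scipy.optimize`) into a list of tensors. The sketch below uses a hypothetical `set_tensors_from_flat_array` and is not BoTorch's implementation; it only illustrates the dtype-relevant choice of casting each slice to the target tensor's own dtype and device, so a float64 NumPy array does not dictate the precision of float32 parameters.

```python
from typing import List

import numpy as np
import torch


def set_tensors_from_flat_array(tensors: List[torch.Tensor], array: np.ndarray) -> None:
    """Hypothetical helper: copy consecutive slices of a 1-d array into `tensors` in place."""
    if array.ndim != 1:
        raise ValueError("expected a 1-d array")
    expected = sum(t.numel() for t in tensors)
    if array.size != expected:
        raise ValueError(f"array has {array.size} elements, expected {expected}")
    start = 0
    with torch.no_grad():  # in-place writes to leaf tensors that may require grad
        for t in tensors:
            n = t.numel()
            # Convert the slice using the target tensor's dtype and device,
            # rather than whatever dtype the NumPy array happens to have.
            chunk = torch.as_tensor(array[start : start + n], dtype=t.dtype, device=t.device)
            t.copy_(chunk.reshape(t.shape))
            start += n
```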

> For the second one, I'll also see if I can come up with a solution, either in the same PR or in another one (which I guess would be...

> The string parameter could still be useful for optimisation-based approaches as there are generation strategies being pushed for text-serialisable problems such as molecule properties (SMILES) and protein design...

1. Yes, this is important. We'll routinely want to draw multiple samples from an MVN of given batch and event shape, so we pass in base samples with additional leading...
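To make the shape convention concrete, here is a plain-PyTorch sketch of reparameterized sampling from a batched multivariate normal using fixed base samples. It is not GPyTorch's implementation, and `rsample_with_base_samples` is a hypothetical name. The base samples carry extra leading `sample_shape` dimensions in front of `batch_shape x n`, and the same Cholesky factor is applied to all of them via broadcasting.

```python
import torch


def rsample_with_base_samples(
    mean: torch.Tensor, covar: torch.Tensor, base_samples: torch.Tensor
) -> torch.Tensor:
    """Hypothetical helper: reparameterized MVN samples from fixed N(0, I) draws.

    mean:         batch_shape x n
    covar:        batch_shape x n x n
    base_samples: sample_shape x batch_shape x n
    returns:      sample_shape x batch_shape x n
    """
    if base_samples.shape[-mean.dim():] != mean.shape:
        raise ValueError("trailing dims of base_samples must equal batch_shape x n")
    L = torch.linalg.cholesky(covar)  # batch_shape x n x n, lower triangular
    # Compute L @ z for every leading sample index; matmul broadcasts over sample_shape.
    return mean + (L @ base_samples.unsqueeze(-1)).squeeze(-1)
```

Reusing the same `base_samples` across calls gives identical draws, which is what makes this useful for quasi-Monte Carlo acquisition functions with fixed samples.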

Sorry, I haven’t been able to look into this more; let me try to run this over the weekend.

OK, so the failure mode I'm running into on this is the following:

```
import torch
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal

mvn = MultivariateNormal(torch.rand(3), torch.eye(3))
mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
mtmvn.rsample(torch.Size([4]))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)...
```