[Bug] Inefficient posterior evaluation of `SaasFullyBayesianSingleTaskGP` when `q=1`
🐛 Bug
Evaluating an acquisition function with q=1 on a `SaasFullyBayesianSingleTaskGP` requires an unnecessarily large amount of memory, due to an inefficient broadcasted matmul operation.
In the example below, the following line multiplies a tensor of size [256, 16, 1, 2048] with a tensor of size [16, 2048, 2048]. Because matmul materializes the broadcast batch dimensions, this requires the allocation of 128GB of memory:
https://github.com/cornellius-gp/gpytorch/blob/9551eba889adf835b69cfd86e9a5d584fb61cdcc/gpytorch/models/exact_prediction_strategies.py#L118
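As a sanity check on where that number comes from (my arithmetic, not profiler output): matmul broadcasts the second operand against the 256-element batch dimension, materializing a [256, 16, 2048, 2048] tensor, which in double precision is exactly 128GB:

# Size of the broadcast-expanded operand that matmul materializes (float64).
print(256 * 16 * 2048 * 2048 * 8 / 1024**3)  # 128.0 (GiB)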
To reproduce
**Code snippet to reproduce**
import torch
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.transforms import Standardize
from botorch import fit_fully_bayesian_model_nuts
from botorch.acquisition import UpperConfidenceBound

n_train = 2048
n_test = 256
d = 256

tkwargs = {
    "device": torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.double,
}
train_X = torch.rand(n_train, d, **tkwargs)
test_X = torch.rand(n_test, d, **tkwargs)
train_Y = torch.sin(train_X[:, :1])
test_Y = torch.sin(test_X[:, :1])

gp = SaasFullyBayesianSingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    outcome_transform=Standardize(m=1),
)
fit_fully_bayesian_model_nuts(
    gp,
    warmup_steps=4,
    num_samples=16,
    thinning=1,
)
ucb = UpperConfidenceBound(gp, beta=2.5)
acq_values = ucb(test_X[:, None, :])
**Stack trace/error message**
Traceback (most recent call last):
  File "/tmp/ipykernel_3377365/3398296989.py", line 3, in <module>
    acq_values = ucb(test_X[:, None, :])
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 786, in forward
    mean, sigma = self._mean_and_sigma(X)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 106, in _mean_and_sigma
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/fully_bayesian.py", line 536, in posterior
    posterior = super().posterior(
                ^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
          ^^^^^^^
  ...
    return test_train_covar.matmul(precomputed_cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB. GPU 3 has a total capacity of 79.15 GiB of which 44.45 GiB is free. Including non-PyTorch memory, this process has 34.69 GiB memory in use. Of the allocated memory 23.74 GiB is allocated by PyTorch, and 10.42 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Expected Behavior
The memory usage for this operation is very high because torch.matmul is inefficient for such broadcasted batched matrix-vector multiplications. If the same operation is written as an einsum, or transposed so that it becomes a matrix-matrix multiplication, both the memory usage and the computation time are substantially reduced.
For example, below is a demonstration of two alternative operations which reduce the memory and computation time by orders of magnitude:
import torch

device = "cuda:3"

# Matrices to multiply
torch.manual_seed(50)
a = torch.randn((256, 16, 1, 1024), device=device)
b = torch.randn((16, 1024, 1024), device=device)

def profile(func):
    torch.cuda.reset_peak_memory_stats(device=device)
    m0 = torch.cuda.max_memory_allocated(device=device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = func()
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end)  # Event.elapsed_time already returns milliseconds
    m1 = torch.cuda.max_memory_allocated(device=device)
    print(f"Memory used: {(m1 - m0) / 1024**3:.2f}GB")
    print(f"Time: {t:.6f} ms")
    return out

with torch.no_grad():
    print("matmul")
    c = profile(lambda: torch.matmul(a, b))

    print("\neinsum")
    c_einsum = profile(lambda: torch.einsum("...ij,...jk", a, b))
    print(f"Max error: {(c_einsum - c).abs().max().cpu().item():.7f}")

    print("\ntransposed matmul")
    c_transpose = profile(lambda: torch.matmul(a.transpose(0, 2), b).transpose(0, 2))
    print(f"Max error: {(c_transpose - c).abs().max().cpu().item():.7f}")
matmul
Memory used: 16.02GB
Time: 261.343986 ms
einsum
Memory used: 0.02GB
Time: 160.416007 ms
Max error: 0.0002327
transposed matmul
Memory used: 0.02GB
Time: 118.303999 ms
Max error: 0.0002327
System information
- BoTorch Version 0.9.5
- GPyTorch Version 1.11
- PyTorch Version 2.2.0+cu121
- Computer OS: Rocky Linux release 8.9
- GPU: NVIDIA A100 80GB PCIe
Thanks for raising this; it's a great catch. Since this call happens in gpytorch we'll have to make the change there and ensure that it is compatible with all kinds of other (non-fully-Bayesian) scenarios (I'm not sure exactly what shapes are encountered in this call), but we will definitely fix this.
cc @dme65, @esantorella
Thank you! I did initially try to come up with something to contribute to GPyTorch and/or linear_operator, but it was harder than anticipated to make it compatible without introducing slowdowns in other situations, so I thought I'd report it here for now. For example, einsum is faster than matmul in this specific situation but has the potential to be much slower in others.
(n.b. I've just corrected a small mistake in the profiling code in the issue above)
> I did initially try to come up with something to contribute to GPyTorch and/or linear_operator, but it was harder than anticipated to make it compatible without introducing slowdowns in other situations, so I thought I'd report it here for now.
Interesting. Do you happen to have some artifacts of those attempts that you could share? That would be very helpful.
What I've done at the moment is just replace the aforementioned matmul with the equivalent einsum to unblock my work, but that of course only works when neither of the tensors is a LinearOperator (as einsum is not implemented for LinearOperators).
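For concreteness, the replacement I'm using locally looks like the following (a sketch, with `test_train_covar` and `precomputed_cache` standing in for the dense tensors at the linked gpytorch line, scaled down to laptop-friendly shapes):

import torch

# Stand-ins for the operands at the linked line; einsum broadcasts the batch
# dimensions without materializing the expanded second operand.
test_train_covar = torch.randn(256, 16, 1, 512, dtype=torch.double)
precomputed_cache = torch.randn(16, 512, 512, dtype=torch.double)

res_matmul = test_train_covar.matmul(precomputed_cache)
res_einsum = torch.einsum("...ik,...kj->...ij", test_train_covar, precomputed_cache)
assert torch.allclose(res_matmul, res_einsum)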
Below is a script that benchmarks matmul against einsum and shows that matmul is generally faster except in some specific situations (I had to restrict the tensor sizes to run on a laptop, as running on a shared server interfered with the timing results too much). It does not report memory usage.
Those specific situations where matmul is inefficient appear to be:
- `a` is a (batched) row vector and one of the batch dimensions of `b` is broadcasted
- `b` is a (batched) column vector and one of the batch dimensions of `a` is broadcasted
In both situations, I'd guess that the appropriate set of matrix transposes and matmul would be more efficient than einsum, but I haven't tested this.
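As an untested illustration of that guess for the second case (my own example shapes, not taken from the benchmark below):

import torch

# `b` is a batched column vector and the leading batch dimension of `b` is
# broadcast against `a`.
a = torch.randn(16, 128, 128)
b = torch.randn(256, 16, 128, 1)

out_naive = torch.matmul(a, b)  # 256 broadcasted matrix-vector products

# Gather the 256 vectors as columns of one matrix per `a` batch, turning this
# into a single batch of 16 matrix-matrix products, then undo the permutation.
out_fast = torch.matmul(a, b.transpose(0, -1)).transpose(0, -1)

assert out_fast.shape == out_naive.shape == (256, 16, 128, 1)
assert torch.allclose(out_naive, out_fast, atol=1e-4)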
Code
import torch
import torch.utils.benchmark as benchmark
from tqdm import tqdm

device = "cuda:0"
f_out = "einsum.txt"

def wrap(f):
    def wrapped(a_sz, b_sz, device):
        a = torch.randn(a_sz, device=device)
        b = torch.randn(b_sz, device=device)
        return f(a, b)
    return wrapped

matmul = wrap(torch.matmul)

@wrap
def einsum(a, b):
    return torch.einsum("...ik,...kj", a, b)

if __name__ == "__main__":
    sizes = []
    batch_funcs = [
        ("Full", lambda i_batch, j_batch: ((i_batch, j_batch), (i_batch, j_batch))),
        ("No a0", lambda i_batch, j_batch: ((j_batch,), (i_batch, j_batch))),
        ("No b0", lambda i_batch, j_batch: ((i_batch, j_batch), (j_batch,))),
    ]
    for batch_type, batch_func in batch_funcs:
        for i_batch in (256, 32, 1):
            for j_batch in (32, 16, 1):
                a_batch, b_batch = batch_func(i_batch, j_batch)
                if None in (a_batch, b_batch):
                    continue
                for i_size in (128, 64, 4):
                    sizes.append(
                        (
                            batch_type,
                            "Matrix-matrix product",
                            a_batch + (i_size, i_size),
                            b_batch + (i_size, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Matrix-vector product",
                            a_batch + (i_size, i_size),
                            b_batch + (i_size, 1),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Transposed MVP",
                            a_batch + (1, i_size),
                            b_batch + (i_size, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Vector outer product",
                            a_batch + (i_size, 1),
                            b_batch + (1, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Vector inner product",
                            a_batch + (1, i_size),
                            b_batch + (i_size, 1),
                        )
                    )

    results = []
    with torch.no_grad():
        pbar = tqdm(sizes)
        for env, label, a_sz, b_sz in pbar:
            sub_label = f"{a_sz}x{b_sz}"
            pbar.set_description(sub_label)
            timers = [
                benchmark.Timer(
                    stmt="matmul(a_sz, b_sz, device)",
                    globals={"device": device, "a_sz": a_sz, "b_sz": b_sz},
                    description="matmul",
                    setup="from __main__ import matmul",
                    label=label,
                    sub_label=sub_label,
                    env=env,
                ),
                benchmark.Timer(
                    stmt="einsum(a_sz, b_sz, device)",
                    globals={"device": device, "a_sz": a_sz, "b_sz": b_sz},
                    description="einsum",
                    setup="from __main__ import einsum",
                    label=label,
                    sub_label=sub_label,
                    env=env,
                ),
            ]
            for timer in timers:
                result = timer.adaptive_autorange(min_run_time=1)
                results.append(result)
                pbar.write(str(result))

    compare = benchmark.Compare(results)
    compare.colorize(rowwise=True)
    with open(f_out, "wt") as f:
        f.write(str(compare))
Timing results
[------------------------ Matrix-matrix product -------------------------]
| matmul | einsum
1 threads: ---------------------------------------------------------------
(Full) (256, 32, 128, 128)x(256, 32, 128, 128) | 39434.7 | 39591.2
(256, 32, 64, 64)x(256, 32, 64, 64) | 5629.9 | 5642.9
(256, 32, 4, 4)x(256, 32, 4, 4) | 170.2 | 170.6
(256, 16, 128, 128)x(256, 16, 128, 128) | 20098.4 | 20091.8
(256, 16, 64, 64)x(256, 16, 64, 64) | 2838.9 | 2837.1
(256, 16, 4, 4)x(256, 16, 4, 4) | 82.0 | 82.4
(256, 1, 128, 128)x(256, 1, 128, 128) | 1260.6 | 1262.0
(256, 1, 64, 64)x(256, 1, 64, 64) | 196.1 | 196.2
(256, 1, 4, 4)x(256, 1, 4, 4) | 59.6 | 81.5
(32, 32, 128, 128)x(32, 32, 128, 128) | 4865.4 | 5037.6
(32, 32, 64, 64)x(32, 32, 64, 64) | 724.0 | 724.2
(32, 32, 4, 4)x(32, 32, 4, 4) | 52.7 | 92.4
(32, 16, 128, 128)x(32, 16, 128, 128) | 2511.0 | 2523.9
(32, 16, 64, 64)x(32, 16, 64, 64) | 371.7 | 371.6
(32, 16, 4, 4)x(32, 16, 4, 4) | 64.1 | 88.7
(32, 1, 128, 128)x(32, 1, 128, 128) | 170.6 | 172.2
(32, 1, 64, 64)x(32, 1, 64, 64) | 62.5 | 93.1
(32, 1, 4, 4)x(32, 1, 4, 4) | 57.9 | 80.1
(1, 32, 128, 128)x(1, 32, 128, 128) | 170.2 | 172.1
(1, 32, 64, 64)x(1, 32, 64, 64) | 51.7 | 81.9
(1, 32, 4, 4)x(1, 32, 4, 4) | 55.1 | 79.1
(1, 16, 128, 128)x(1, 16, 128, 128) | 75.1 | 91.9
(1, 16, 64, 64)x(1, 16, 64, 64) | 64.0 | 79.5
(1, 16, 4, 4)x(1, 16, 4, 4) | 56.4 | 79.8
(1, 1, 128, 128)x(1, 1, 128, 128) | 66.8 | 82.9
(1, 1, 64, 64)x(1, 1, 64, 64) | 68.6 | 83.6
(1, 1, 4, 4)x(1, 1, 4, 4) | 64.2 | 82.8
(No a0) (32, 128, 128)x(256, 32, 128, 128) | 41409.0 | 41664.8
(32, 64, 64)x(256, 32, 64, 64) | 5943.2 | 5717.3
(32, 4, 4)x(256, 32, 4, 4) | 170.0 | 108.7
(16, 128, 128)x(256, 16, 128, 128) | 20606.6 | 20651.3
(16, 64, 64)x(256, 16, 64, 64) | 2697.3 | 2869.0
(16, 4, 4)x(256, 16, 4, 4) | 80.8 | 99.6
(1, 128, 128)x(256, 1, 128, 128) | 1102.7 | 955.0
(1, 64, 64)x(256, 1, 64, 64) | 136.9 | 188.1
(1, 4, 4)x(256, 1, 4, 4) | 65.9 | 104.4
(32, 128, 128)x(32, 32, 128, 128) | 5099.2 | 5208.4
(32, 64, 64)x(32, 32, 64, 64) | 771.4 | 742.5
(32, 4, 4)x(32, 32, 4, 4) | 86.0 | 102.9
(16, 128, 128)x(32, 16, 128, 128) | 2587.0 | 2606.2
(16, 64, 64)x(32, 16, 64, 64) | 359.4 | 378.0
(16, 4, 4)x(32, 16, 4, 4) | 92.2 | 117.4
(1, 128, 128)x(32, 1, 128, 128) | 151.2 | 137.0
(1, 64, 64)x(32, 1, 64, 64) | 65.3 | 109.1
(1, 4, 4)x(32, 1, 4, 4) | 73.6 | 115.5
(32, 128, 128)x(1, 32, 128, 128) | 170.8 | 172.3
(32, 64, 64)x(1, 32, 64, 64) | 55.8 | 83.5
(32, 4, 4)x(1, 32, 4, 4) | 60.1 | 81.1
(16, 128, 128)x(1, 16, 128, 128) | 75.4 | 88.6
(16, 64, 64)x(1, 16, 64, 64) | 57.9 | 81.4
(16, 4, 4)x(1, 16, 4, 4) | 58.3 | 80.7
(1, 128, 128)x(1, 1, 128, 128) | 68.5 | 81.7
(1, 64, 64)x(1, 1, 64, 64) | 69.5 | 87.2
(1, 4, 4)x(1, 1, 4, 4) | 73.2 | 97.5
(No b0) (256, 32, 128, 128)x(32, 128, 128) | 41350.2 | 41628.3
(256, 32, 64, 64)x(32, 64, 64) | 5935.2 | 6984.1
(256, 32, 4, 4)x(32, 4, 4) | 169.9 | 118.6
(256, 16, 128, 128)x(16, 128, 128) | 20609.6 | 20646.5
(256, 16, 64, 64)x(16, 64, 64) | 2692.2 | 3489.1
(256, 16, 4, 4)x(16, 4, 4) | 80.9 | 112.8
(256, 1, 128, 128)x(1, 128, 128) | 1101.0 | 753.1
(256, 1, 64, 64)x(1, 64, 64) | 139.5 | 178.4
(256, 1, 4, 4)x(1, 4, 4) | 66.5 | 81.2
(32, 32, 128, 128)x(32, 128, 128) | 5085.8 | 5191.4
(32, 32, 64, 64)x(32, 64, 64) | 766.8 | 894.6
(32, 32, 4, 4)x(32, 4, 4) | 96.6 | 163.8
(32, 16, 128, 128)x(16, 128, 128) | 2607.6 | 2612.5
(32, 16, 64, 64)x(16, 64, 64) | 358.3 | 455.5
(32, 16, 4, 4)x(16, 4, 4) | 70.5 | 105.0
(32, 1, 128, 128)x(1, 128, 128) | 150.8 | 109.2
(32, 1, 64, 64)x(1, 64, 64) | 77.3 | 94.8
(32, 1, 4, 4)x(1, 4, 4) | 65.8 | 89.3
(1, 32, 128, 128)x(32, 128, 128) | 170.2 | 171.9
(1, 32, 64, 64)x(32, 64, 64) | 51.6 | 83.7
(1, 32, 4, 4)x(32, 4, 4) | 68.6 | 98.6
(1, 16, 128, 128)x(16, 128, 128) | 74.6 | 86.9
(1, 16, 64, 64)x(16, 64, 64) | 73.1 | 100.9
(1, 16, 4, 4)x(16, 4, 4) | 71.5 | 98.6
(1, 1, 128, 128)x(1, 128, 128) | 78.0 | 99.1
(1, 1, 64, 64)x(1, 64, 64) | 68.6 | 100.3
(1, 1, 4, 4)x(1, 4, 4) | 73.6 | 95.4
Times are in microseconds (us).
[----------------------- Matrix-vector product ------------------------]
| matmul | einsum
1 threads: -------------------------------------------------------------
(Full) (256, 32, 128, 128)x(256, 32, 128, 1) | 8758.8 | 8771.6
(256, 32, 64, 64)x(256, 32, 64, 1) | 2265.5 | 2257.6
(256, 32, 4, 4)x(256, 32, 4, 1) | 64.4 | 84.1
(256, 16, 128, 128)x(256, 16, 128, 1) | 4432.1 | 4430.9
(256, 16, 64, 64)x(256, 16, 64, 1) | 1132.3 | 1131.9
(256, 16, 4, 4)x(256, 16, 4, 1) | 57.5 | 80.1
(256, 1, 128, 128)x(256, 1, 128, 1) | 287.1 | 287.1
(256, 1, 64, 64)x(256, 1, 64, 1) | 82.3 | 85.1
(256, 1, 4, 4)x(256, 1, 4, 1) | 57.9 | 81.1
(32, 32, 128, 128)x(32, 32, 128, 1) | 1098.5 | 1108.3
(32, 32, 64, 64)x(32, 32, 64, 1) | 295.7 | 295.8
(32, 32, 4, 4)x(32, 32, 4, 1) | 64.2 | 82.5
(32, 16, 128, 128)x(32, 16, 128, 1) | 558.9 | 558.7
(32, 16, 64, 64)x(32, 16, 64, 1) | 152.5 | 152.5
(32, 16, 4, 4)x(32, 16, 4, 1) | 64.8 | 80.1
(32, 1, 128, 128)x(32, 1, 128, 1) | 61.9 | 92.6
(32, 1, 64, 64)x(32, 1, 64, 1) | 64.6 | 87.5
(32, 1, 4, 4)x(32, 1, 4, 1) | 60.7 | 80.5
(1, 32, 128, 128)x(1, 32, 128, 1) | 50.8 | 87.1
(1, 32, 64, 64)x(1, 32, 64, 1) | 55.7 | 79.5
(1, 32, 4, 4)x(1, 32, 4, 1) | 61.6 | 80.8
(1, 16, 128, 128)x(1, 16, 128, 1) | 55.0 | 80.2
(1, 16, 64, 64)x(1, 16, 64, 1) | 55.8 | 80.5
(1, 16, 4, 4)x(1, 16, 4, 1) | 58.6 | 80.1
(1, 1, 128, 128)x(1, 1, 128, 1) | 55.8 | 81.2
(1, 1, 64, 64)x(1, 1, 64, 1) | 58.5 | 81.3
(1, 1, 4, 4)x(1, 1, 4, 1) | 71.3 | 103.5
(No a0) (32, 128, 128)x(256, 32, 128, 1) | 11368.5 | 317.5
(32, 64, 64)x(256, 32, 64, 1) | 2704.4 | 89.7
(32, 4, 4)x(256, 32, 4, 1) | 78.8 | 83.6
(16, 128, 128)x(256, 16, 128, 1) | 5688.9 | 167.6
(16, 64, 64)x(256, 16, 64, 1) | 1029.9 | 84.1
(16, 4, 4)x(256, 16, 4, 1) | 79.0 | 89.9
(1, 128, 128)x(256, 1, 128, 1) | 126.6 | 89.2
(1, 64, 64)x(256, 1, 64, 1) | 80.8 | 106.6
(1, 4, 4)x(256, 1, 4, 1) | 66.1 | 89.6
(32, 128, 128)x(32, 32, 128, 1) | 1448.9 | 85.4
(32, 64, 64)x(32, 32, 64, 1) | 354.7 | 84.9
(32, 4, 4)x(32, 32, 4, 1) | 72.8 | 84.1
(16, 128, 128)x(32, 16, 128, 1) | 733.7 | 83.7
(16, 64, 64)x(32, 16, 64, 1) | 141.3 | 100.6
(16, 4, 4)x(32, 16, 4, 1) | 92.3 | 104.4
(1, 128, 128)x(32, 1, 128, 1) | 84.9 | 94.6
(1, 64, 64)x(32, 1, 64, 1) | 73.6 | 105.4
(1, 4, 4)x(32, 1, 4, 1) | 75.4 | 101.3
(32, 128, 128)x(1, 32, 128, 1) | 67.1 | 91.3
(32, 64, 64)x(1, 32, 64, 1) | 59.2 | 82.1
(32, 4, 4)x(1, 32, 4, 1) | 63.9 | 81.2
(16, 128, 128)x(1, 16, 128, 1) | 57.5 | 83.0
(16, 64, 64)x(1, 16, 64, 1) | 57.7 | 83.1
(16, 4, 4)x(1, 16, 4, 1) | 65.2 | 81.3
(1, 128, 128)x(1, 1, 128, 1) | 57.7 | 81.2
(1, 64, 64)x(1, 1, 64, 1) | 58.4 | 80.3
(1, 4, 4)x(1, 1, 4, 1) | 80.6 | 107.0
(No b0) (256, 32, 128, 128)x(32, 128, 1) | 8751.5 | 15982.3
(256, 32, 64, 64)x(32, 64, 1) | 2280.4 | 4043.9
(256, 32, 4, 4)x(32, 4, 1) | 84.3 | 124.5
(256, 16, 128, 128)x(16, 128, 1) | 4419.8 | 8036.1
(256, 16, 64, 64)x(16, 64, 1) | 1131.8 | 2028.7
(256, 16, 4, 4)x(16, 4, 1) | 72.7 | 107.3
(256, 1, 128, 128)x(1, 128, 1) | 283.3 | 271.1
(256, 1, 64, 64)x(1, 64, 1) | 79.8 | 81.9
(256, 1, 4, 4)x(1, 4, 1) | 67.4 | 96.3
(32, 32, 128, 128)x(32, 128, 1) | 1105.2 | 1999.8
(32, 32, 64, 64)x(32, 64, 1) | 292.8 | 518.5
(32, 32, 4, 4)x(32, 4, 1) | 100.1 | 134.8
(32, 16, 128, 128)x(16, 128, 1) | 559.8 | 1013.7
(32, 16, 64, 64)x(16, 64, 1) | 152.6 | 265.0
(32, 16, 4, 4)x(16, 4, 1) | 82.8 | 111.6
(32, 1, 128, 128)x(1, 128, 1) | 57.8 | 82.0
(32, 1, 64, 64)x(1, 64, 1) | 69.1 | 82.4
(32, 1, 4, 4)x(1, 4, 1) | 72.8 | 95.6
(1, 32, 128, 128)x(32, 128, 1) | 53.0 | 82.7
(1, 32, 64, 64)x(32, 64, 1) | 56.4 | 98.6
(1, 32, 4, 4)x(32, 4, 1) | 76.3 | 97.1
(1, 16, 128, 128)x(16, 128, 1) | 55.6 | 98.5
(1, 16, 64, 64)x(16, 64, 1) | 67.6 | 95.9
(1, 16, 4, 4)x(16, 4, 1) | 76.1 | 97.6
(1, 1, 128, 128)x(1, 128, 1) | 65.4 | 96.5
(1, 1, 64, 64)x(1, 64, 1) | 69.7 | 97.9
(1, 1, 4, 4)x(1, 4, 1) | 80.8 | 109.2
Times are in microseconds (us).
[--------------------------- Transposed MVP ---------------------------]
| matmul | einsum
1 threads: -------------------------------------------------------------
(Full) (256, 32, 1, 128)x(256, 32, 128, 128) | 8156.9 | 8164.4
(256, 32, 1, 64)x(256, 32, 64, 64) | 2079.4 | 2079.0
(256, 32, 1, 4)x(256, 32, 4, 4) | 69.6 | 88.3
(256, 16, 1, 128)x(256, 16, 128, 128) | 4097.1 | 4095.3
(256, 16, 1, 64)x(256, 16, 64, 64) | 1049.0 | 1048.5
(256, 16, 1, 4)x(256, 16, 4, 4) | 54.5 | 78.8
(256, 1, 1, 128)x(256, 1, 128, 128) | 270.3 | 270.5
(256, 1, 1, 64)x(256, 1, 64, 64) | 74.5 | 75.1
(256, 1, 1, 4)x(256, 1, 4, 4) | 57.6 | 80.5
(32, 32, 1, 128)x(32, 32, 128, 128) | 1036.2 | 1036.7
(32, 32, 1, 64)x(32, 32, 64, 64) | 274.2 | 274.4
(32, 32, 1, 4)x(32, 32, 4, 4) | 54.2 | 81.5
(32, 16, 1, 128)x(32, 16, 128, 128) | 530.1 | 529.6
(32, 16, 1, 64)x(32, 16, 64, 64) | 141.4 | 141.7
(32, 16, 1, 4)x(32, 16, 4, 4) | 65.6 | 85.6
(32, 1, 1, 128)x(32, 1, 128, 128) | 69.3 | 121.2
(32, 1, 1, 64)x(32, 1, 64, 64) | 62.4 | 80.2
(32, 1, 1, 4)x(32, 1, 4, 4) | 58.5 | 80.4
(1, 32, 1, 128)x(1, 32, 128, 128) | 54.5 | 90.1
(1, 32, 1, 64)x(1, 32, 64, 64) | 58.0 | 81.3
(1, 32, 1, 4)x(1, 32, 4, 4) | 57.5 | 79.8
(1, 16, 1, 128)x(1, 16, 128, 128) | 65.3 | 81.9
(1, 16, 1, 64)x(1, 16, 64, 64) | 57.5 | 90.7
(1, 16, 1, 4)x(1, 16, 4, 4) | 59.5 | 79.7
(1, 1, 1, 128)x(1, 1, 128, 128) | 69.3 | 79.4
(1, 1, 1, 64)x(1, 1, 64, 64) | 68.1 | 82.0
(1, 1, 1, 4)x(1, 1, 4, 4) | 73.4 | 82.1
(No a0) (32, 1, 128)x(256, 32, 128, 128) | 8155.8 | 15325.0
(32, 1, 64)x(256, 32, 64, 64) | 2071.2 | 3938.2
(32, 1, 4)x(256, 32, 4, 4) | 71.9 | 103.3
(16, 1, 128)x(256, 16, 128, 128) | 4085.8 | 7668.6
(16, 1, 64)x(256, 16, 64, 64) | 1045.5 | 1971.8
(16, 1, 4)x(256, 16, 4, 4) | 79.8 | 102.1
(1, 1, 128)x(256, 1, 128, 128) | 264.5 | 496.9
(1, 1, 64)x(256, 1, 64, 64) | 72.9 | 143.9
(1, 1, 4)x(256, 1, 4, 4) | 68.6 | 110.5
(32, 1, 128)x(32, 32, 128, 128) | 1032.5 | 1931.6
(32, 1, 64)x(32, 32, 64, 64) | 271.5 | 509.4
(32, 1, 4)x(32, 32, 4, 4) | 80.5 | 120.2
(16, 1, 128)x(32, 16, 128, 128) | 524.2 | 979.9
(16, 1, 64)x(32, 16, 64, 64) | 142.5 | 256.6
(16, 1, 4)x(32, 16, 4, 4) | 103.2 | 119.8
(1, 1, 128)x(32, 1, 128, 128) | 61.2 | 102.4
(1, 1, 64)x(32, 1, 64, 64) | 69.2 | 121.2
(1, 1, 4)x(32, 1, 4, 4) | 75.3 | 119.2
(32, 1, 128)x(1, 32, 128, 128) | 70.1 | 92.3
(32, 1, 64)x(1, 32, 64, 64) | 66.5 | 81.0
(32, 1, 4)x(1, 32, 4, 4) | 66.7 | 81.0
(16, 1, 128)x(1, 16, 128, 128) | 66.1 | 85.0
(16, 1, 64)x(1, 16, 64, 64) | 58.3 | 83.0
(16, 1, 4)x(1, 16, 4, 4) | 61.5 | 82.2
(1, 1, 128)x(1, 1, 128, 128) | 69.4 | 79.5
(1, 1, 64)x(1, 1, 64, 64) | 67.1 | 79.2
(1, 1, 4)x(1, 1, 4, 4) | 76.8 | 97.8
(No b0) (256, 32, 1, 128)x(32, 128, 128) | 11310.7 | 307.8
(256, 32, 1, 64)x(32, 64, 64) | 2654.2 | 103.4
(256, 32, 1, 4)x(32, 4, 4) | 88.4 | 101.2
(256, 16, 1, 128)x(16, 128, 128) | 5672.2 | 161.3
(256, 16, 1, 64)x(16, 64, 64) | 971.0 | 97.6
(256, 16, 1, 4)x(16, 4, 4) | 71.9 | 93.0
(256, 1, 1, 128)x(1, 128, 128) | 63.8 | 104.2
(256, 1, 1, 64)x(1, 64, 64) | 72.3 | 84.2
(256, 1, 1, 4)x(1, 4, 4) | 58.9 | 98.4
(32, 32, 1, 128)x(32, 128, 128) | 1444.0 | 97.6
(32, 32, 1, 64)x(32, 64, 64) | 346.9 | 119.0
(32, 32, 1, 4)x(32, 4, 4) | 101.3 | 119.3
(32, 16, 1, 128)x(16, 128, 128) | 731.4 | 104.3
(32, 16, 1, 64)x(16, 64, 64) | 134.4 | 95.6
(32, 16, 1, 4)x(16, 4, 4) | 79.5 | 96.8
(32, 1, 1, 128)x(1, 128, 128) | 70.0 | 83.4
(32, 1, 1, 64)x(1, 64, 64) | 67.0 | 98.1
(32, 1, 1, 4)x(1, 4, 4) | 74.5 | 94.4
(1, 32, 1, 128)x(32, 128, 128) | 54.4 | 90.8
(1, 32, 1, 64)x(32, 64, 64) | 64.4 | 92.0
(1, 32, 1, 4)x(32, 4, 4) | 71.5 | 102.9
(1, 16, 1, 128)x(16, 128, 128) | 64.8 | 100.4
(1, 16, 1, 64)x(16, 64, 64) | 69.4 | 97.8
(1, 16, 1, 4)x(16, 4, 4) | 75.3 | 98.7
(1, 1, 1, 128)x(1, 128, 128) | 69.8 | 99.9
(1, 1, 1, 64)x(1, 64, 64) | 76.9 | 95.5
(1, 1, 1, 4)x(1, 4, 4) | 76.7 | 108.9
Times are in microseconds (us).
[---------------------- Vector outer product ----------------------]
| matmul | einsum
1 threads: ---------------------------------------------------------
(Full) (256, 32, 128, 1)x(256, 32, 1, 128) | 3857.8 | 5171.4
(256, 32, 64, 1)x(256, 32, 1, 64) | 1042.8 | 1337.9
(256, 32, 4, 1)x(256, 32, 1, 4) | 93.5 | 62.8
(256, 16, 128, 1)x(256, 16, 1, 128) | 1938.2 | 2597.6
(256, 16, 64, 1)x(256, 16, 1, 64) | 528.5 | 670.7
(256, 16, 4, 1)x(256, 16, 1, 4) | 52.2 | 63.0
(256, 1, 128, 1)x(256, 1, 1, 128) | 131.4 | 168.9
(256, 1, 64, 1)x(256, 1, 1, 64) | 56.0 | 66.5
(256, 1, 4, 1)x(256, 1, 1, 4) | 64.9 | 67.7
(32, 32, 128, 1)x(32, 32, 1, 128) | 497.3 | 661.1
(32, 32, 64, 1)x(32, 32, 1, 64) | 141.3 | 174.8
(32, 32, 4, 1)x(32, 32, 1, 4) | 54.7 | 62.3
(32, 16, 128, 1)x(32, 16, 1, 128) | 254.0 | 334.6
(32, 16, 64, 1)x(32, 16, 1, 64) | 75.3 | 90.6
(32, 16, 4, 1)x(32, 16, 1, 4) | 57.8 | 63.6
(32, 1, 128, 1)x(32, 1, 1, 128) | 78.6 | 78.3
(32, 1, 64, 1)x(32, 1, 1, 64) | 55.1 | 62.6
(32, 1, 4, 1)x(32, 1, 1, 4) | 58.0 | 66.9
(1, 32, 128, 1)x(1, 32, 1, 128) | 54.7 | 63.0
(1, 32, 64, 1)x(1, 32, 1, 64) | 64.1 | 62.1
(1, 32, 4, 1)x(1, 32, 1, 4) | 58.2 | 63.1
(1, 16, 128, 1)x(1, 16, 1, 128) | 59.4 | 64.1
(1, 16, 64, 1)x(1, 16, 1, 64) | 59.2 | 71.7
(1, 16, 4, 1)x(1, 16, 1, 4) | 60.0 | 64.8
(1, 1, 128, 1)x(1, 1, 1, 128) | 62.0 | 62.8
(1, 1, 64, 1)x(1, 1, 1, 64) | 62.2 | 66.4
(1, 1, 4, 1)x(1, 1, 1, 4) | 61.1 | 65.9
(No a0) (32, 128, 1)x(256, 32, 1, 128) | 3853.3 | 5370.5
(32, 64, 1)x(256, 32, 1, 64) | 1041.9 | 1391.3
(32, 4, 1)x(256, 32, 1, 4) | 94.2 | 63.8
(16, 128, 1)x(256, 16, 1, 128) | 1937.8 | 2583.4
(16, 64, 1)x(256, 16, 1, 64) | 529.2 | 700.7
(16, 4, 1)x(256, 16, 1, 4) | 80.1 | 63.6
(1, 128, 1)x(256, 1, 1, 128) | 123.9 | 158.6
(1, 64, 1)x(256, 1, 1, 64) | 61.5 | 63.0
(1, 4, 1)x(256, 1, 1, 4) | 62.0 | 70.0
(32, 128, 1)x(32, 32, 1, 128) | 497.1 | 681.7
(32, 64, 1)x(32, 32, 1, 64) | 142.2 | 181.4
(32, 4, 1)x(32, 32, 1, 4) | 88.5 | 76.3
(16, 128, 1)x(32, 16, 1, 128) | 255.0 | 330.6
(16, 64, 1)x(32, 16, 1, 64) | 82.6 | 94.2
(16, 4, 1)x(32, 16, 1, 4) | 94.0 | 77.1
(1, 128, 1)x(32, 1, 1, 128) | 61.8 | 64.7
(1, 64, 1)x(32, 1, 1, 64) | 80.0 | 83.9
(1, 4, 1)x(32, 1, 1, 4) | 74.9 | 79.5
(32, 128, 1)x(1, 32, 1, 128) | 66.2 | 74.2
(32, 64, 1)x(1, 32, 1, 64) | 59.5 | 64.3
(32, 4, 1)x(1, 32, 1, 4) | 63.0 | 67.1
(16, 128, 1)x(1, 16, 1, 128) | 61.8 | 63.3
(16, 64, 1)x(1, 16, 1, 64) | 61.1 | 69.9
(16, 4, 1)x(1, 16, 1, 4) | 64.5 | 67.6
(1, 128, 1)x(1, 1, 1, 128) | 65.6 | 65.8
(1, 64, 1)x(1, 1, 1, 64) | 60.8 | 66.4
(1, 4, 1)x(1, 1, 1, 4) | 76.1 | 78.9
(No b0) (256, 32, 128, 1)x(32, 1, 128) | 3853.5 | 5463.5
(256, 32, 64, 1)x(32, 1, 64) | 1041.5 | 1371.9
(256, 32, 4, 1)x(32, 1, 4) | 94.3 | 76.2
(256, 16, 128, 1)x(16, 1, 128) | 1937.0 | 2714.4
(256, 16, 64, 1)x(16, 1, 64) | 520.2 | 691.9
(256, 16, 4, 1)x(16, 1, 4) | 81.0 | 73.1
(256, 1, 128, 1)x(1, 1, 128) | 124.1 | 152.4
(256, 1, 64, 1)x(1, 1, 64) | 58.2 | 62.4
(256, 1, 4, 1)x(1, 1, 4) | 76.5 | 82.5
(32, 32, 128, 1)x(32, 1, 128) | 495.6 | 687.8
(32, 32, 64, 1)x(32, 1, 64) | 141.8 | 179.0
(32, 32, 4, 1)x(32, 1, 4) | 77.4 | 66.9
(32, 16, 128, 1)x(16, 1, 128) | 254.5 | 345.6
(32, 16, 64, 1)x(16, 1, 64) | 79.6 | 92.6
(32, 16, 4, 1)x(16, 1, 4) | 81.8 | 82.6
(32, 1, 128, 1)x(1, 1, 128) | 59.4 | 77.1
(32, 1, 64, 1)x(1, 1, 64) | 59.5 | 70.9
(32, 1, 4, 1)x(1, 1, 4) | 69.8 | 74.9
(1, 32, 128, 1)x(32, 1, 128) | 54.6 | 63.7
(1, 32, 64, 1)x(32, 1, 64) | 71.6 | 82.0
(1, 32, 4, 1)x(32, 1, 4) | 72.1 | 79.2
(1, 16, 128, 1)x(16, 1, 128) | 78.2 | 87.9
(1, 16, 64, 1)x(16, 1, 64) | 68.8 | 78.0
(1, 16, 4, 1)x(16, 1, 4) | 77.0 | 79.3
(1, 1, 128, 1)x(1, 1, 128) | 75.9 | 78.9
(1, 1, 64, 1)x(1, 1, 64) | 75.2 | 76.7
(1, 1, 4, 1)x(1, 1, 4) | 76.1 | 78.1
Times are in microseconds (us).
[---------------------- Vector inner product ----------------------]
| matmul | einsum
1 threads: ---------------------------------------------------------
(Full) (256, 32, 1, 128)x(256, 32, 128, 1) | 256.9 | 257.1
(256, 32, 1, 64)x(256, 32, 64, 1) | 141.1 | 141.7
(256, 32, 1, 4)x(256, 32, 4, 1) | 57.9 | 86.4
(256, 16, 1, 128)x(256, 16, 128, 1) | 137.0 | 137.0
(256, 16, 1, 64)x(256, 16, 64, 1) | 74.9 | 85.8
(256, 16, 1, 4)x(256, 16, 4, 1) | 54.5 | 78.6
(256, 1, 1, 128)x(256, 1, 128, 1) | 65.3 | 90.8
(256, 1, 1, 64)x(256, 1, 64, 1) | 59.7 | 83.2
(256, 1, 1, 4)x(256, 1, 4, 1) | 55.1 | 80.6
(32, 32, 1, 128)x(32, 32, 128, 1) | 52.2 | 90.6
(32, 32, 1, 64)x(32, 32, 64, 1) | 65.9 | 80.0
(32, 32, 1, 4)x(32, 32, 4, 1) | 55.5 | 82.7
(32, 16, 1, 128)x(32, 16, 128, 1) | 67.0 | 80.0
(32, 16, 1, 64)x(32, 16, 64, 1) | 66.2 | 97.3
(32, 16, 1, 4)x(32, 16, 4, 1) | 68.4 | 85.7
(32, 1, 1, 128)x(32, 1, 128, 1) | 70.1 | 95.6
(32, 1, 1, 64)x(32, 1, 64, 1) | 67.7 | 79.5
(32, 1, 1, 4)x(32, 1, 4, 1) | 61.2 | 79.5
(1, 32, 1, 128)x(1, 32, 128, 1) | 56.5 | 82.0
(1, 32, 1, 64)x(1, 32, 64, 1) | 62.8 | 79.3
(1, 32, 1, 4)x(1, 32, 4, 1) | 61.6 | 79.3
(1, 16, 1, 128)x(1, 16, 128, 1) | 57.8 | 80.2
(1, 16, 1, 64)x(1, 16, 64, 1) | 64.5 | 79.7
(1, 16, 1, 4)x(1, 16, 4, 1) | 57.7 | 79.9
(1, 1, 1, 128)x(1, 1, 128, 1) | 61.4 | 88.4
(1, 1, 1, 64)x(1, 1, 64, 1) | 67.6 | 91.3
(1, 1, 1, 4)x(1, 1, 4, 1) | 63.5 | 81.8
(No a0) (32, 1, 128)x(256, 32, 128, 1) | 251.3 | 79.1
(32, 1, 64)x(256, 32, 64, 1) | 136.7 | 90.3
(32, 1, 4)x(256, 32, 4, 1) | 73.0 | 86.9
(16, 1, 128)x(256, 16, 128, 1) | 133.2 | 102.1
(16, 1, 64)x(256, 16, 64, 1) | 76.2 | 83.9
(16, 1, 4)x(256, 16, 4, 1) | 80.1 | 87.5
(1, 1, 128)x(256, 1, 128, 1) | 89.5 | 96.2
(1, 1, 64)x(256, 1, 64, 1) | 62.3 | 90.7
(1, 1, 4)x(256, 1, 4, 1) | 62.9 | 88.8
(32, 1, 128)x(32, 32, 128, 1) | 77.6 | 89.8
(32, 1, 64)x(32, 32, 64, 1) | 87.9 | 89.7
(32, 1, 4)x(32, 32, 4, 1) | 84.5 | 104.8
(16, 1, 128)x(32, 16, 128, 1) | 90.0 | 111.5
(16, 1, 64)x(32, 16, 64, 1) | 106.2 | 123.3
(16, 1, 4)x(32, 16, 4, 1) | 104.2 | 103.5
(1, 1, 128)x(32, 1, 128, 1) | 71.5 | 86.5
(1, 1, 64)x(32, 1, 64, 1) | 77.2 | 97.9
(1, 1, 4)x(32, 1, 4, 1) | 74.8 | 102.6
(32, 1, 128)x(1, 32, 128, 1) | 65.6 | 82.8
(32, 1, 64)x(1, 32, 64, 1) | 66.9 | 82.2
(32, 1, 4)x(1, 32, 4, 1) | 64.8 | 80.6
(16, 1, 128)x(1, 16, 128, 1) | 60.4 | 90.8
(16, 1, 64)x(1, 16, 64, 1) | 61.3 | 82.1
(16, 1, 4)x(1, 16, 4, 1) | 65.0 | 81.9
(1, 1, 128)x(1, 1, 128, 1) | 63.0 | 87.5
(1, 1, 64)x(1, 1, 64, 1) | 62.5 | 88.9
(1, 1, 4)x(1, 1, 4, 1) | 76.9 | 98.4
(No b0) (256, 32, 1, 128)x(32, 128, 1) | 251.1 | 98.1
(256, 32, 1, 64)x(32, 64, 1) | 137.6 | 102.3
(256, 32, 1, 4)x(32, 4, 1) | 83.4 | 103.7
(256, 16, 1, 128)x(16, 128, 1) | 133.1 | 89.2
(256, 16, 1, 64)x(16, 64, 1) | 76.2 | 86.6
(256, 16, 1, 4)x(16, 4, 1) | 88.4 | 103.9
(256, 1, 1, 128)x(1, 128, 1) | 77.3 | 95.9
(256, 1, 1, 64)x(1, 64, 1) | 58.9 | 82.1
(256, 1, 1, 4)x(1, 4, 1) | 78.2 | 96.4
(32, 32, 1, 128)x(32, 128, 1) | 94.8 | 105.0
(32, 32, 1, 64)x(32, 64, 1) | 104.9 | 111.4
(32, 32, 1, 4)x(32, 4, 1) | 88.8 | 104.2
(32, 16, 1, 128)x(16, 128, 1) | 82.6 | 89.2
(32, 16, 1, 64)x(16, 64, 1) | 81.1 | 92.9
(32, 16, 1, 4)x(16, 4, 1) | 94.1 | 100.6
(32, 1, 1, 128)x(1, 128, 1) | 70.2 | 100.2
(32, 1, 1, 64)x(1, 64, 1) | 71.3 | 83.7
(32, 1, 1, 4)x(1, 4, 1) | 70.2 | 92.5
(1, 32, 1, 128)x(32, 128, 1) | 57.1 | 83.7
(1, 32, 1, 64)x(32, 64, 1) | 71.2 | 98.5
(1, 32, 1, 4)x(32, 4, 1) | 71.6 | 90.1
(1, 16, 1, 128)x(16, 128, 1) | 67.3 | 90.7
(1, 16, 1, 64)x(16, 64, 1) | 73.2 | 98.5
(1, 16, 1, 4)x(16, 4, 1) | 74.8 | 98.3
(1, 1, 1, 128)x(1, 128, 1) | 76.7 | 105.1
(1, 1, 1, 64)x(1, 64, 1) | 81.4 | 105.8
(1, 1, 1, 4)x(1, 4, 1) | 75.7 | 107.8
Times are in microseconds (us).
Along the same lines as above: I've found another (less significant) memory saving. If I attempt to optimise the acquisition function from the original issue with the setup below, after replacing the matmul with einsum, I still run out of memory on an 80GB A100 while computing the distance matrix.
candidates, acq_values = optimize_acqf(
    ucb,
    bounds=torch.cat((torch.zeros(1, d), torch.ones(1, d))).to(**tkwargs),
    q=1,
    num_restarts=10,
    raw_samples=1024,
)
Traceback:
Traceback (most recent call last):
  File "/tmp/ipykernel_2227345/3762668224.py", line 5, in <module>
    candidates, acq_values = optimize_acqf(
                             ^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
    batch_initial_conditions = opt_inputs.get_ic_generator()(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
                 ^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 786, in forward
    mean, sigma = self._mean_and_sigma(X)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 106, in _mean_and_sigma
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/fully_bayesian.py", line 536, in posterior
    posterior = super().posterior(
                ^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
          ^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py", line 286, in exact_prediction
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/kernel.py", line 532, in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File "/home/.../lib/python3.11/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/matern_kernel.py", line 99, in forward
    distance = self.covar_dist(x1_, x2_, diag=diag, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/kernel.py", line 359, in covar_dist
    return dist_func(x1, x2, x1_eq_x2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/kernel.py", line 59, in dist
    res = torch.cdist(x1, x2)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/functional.py", line 1330, in cdist
    return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 64.03 GiB. GPU 3 has a total capacity of 79.15 GiB of which 9.08 GiB is free. Process 3788106 has 706.00 MiB memory in use. Process 3860132 has 706.00 MiB memory in use. Including non-PyTorch memory, this process has 68.67 GiB memory in use. Of the allocated memory 68.09 GiB is allocated by PyTorch, and 51.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
If I modify the call to torch.cdist to the following, the memory usage is reduced just enough for it to work:
res = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")
For reference: at this point, x1 and x2 have shapes [1024, 16, 1, 256] and [1024, 16, 2049, 256], respectively.
The PyTorch documentation on compute_mode is quite vague, but https://github.com/pytorch/pytorch/issues/42479 suggests that this option is slower but slightly more accurate.
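For intuition, the two modes compute the same distances in different ways. Roughly (my own sketch of the underlying math, not the actual ATen code):

import torch

x1 = torch.randn(4, 3, 256, dtype=torch.double)
x2 = torch.randn(4, 7, 256, dtype=torch.double)

# use_mm_for_euclid_dist: expand ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2 and
# compute the cross term with a single matmul. Fast, but it goes through the
# matmul machinery (with its intermediates) and loses a little accuracy to
# cancellation.
x1_sq = x1.pow(2).sum(-1, keepdim=True)                    # (4, 3, 1)
x2_sq = x2.pow(2).sum(-1, keepdim=True).transpose(-2, -1)  # (4, 1, 7)
d_mm = (x1_sq - 2 * x1 @ x2.transpose(-2, -1) + x2_sq).clamp_min(0).sqrt()

# donot_use_mm_for_euclid_dist: take pairwise differences and norm them
# directly. Slower, but slightly more accurate and it avoids the matmul path.
d_direct = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).norm(dim=-1)

assert torch.allclose(d_mm, d_direct)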
(This issue is also more of a GPyTorch issue than a BoTorch issue, but it's closely related to this so I'm just adding it here for now; let me know if you want me to create a new issue).
There is a very similar issue with KroneckerMultiTaskGP. Below is a minimal reproducible example.
import torch
from botorch.models import KroneckerMultiTaskGP
n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")
train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)
test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)
gp = KroneckerMultiTaskGP(train_x, train_y)
posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))
Stack trace
---------------------------------------------------------------------------
OutOfMemoryError Traceback (most recent call last)
Cell In[7], line 18
15 gp = KroneckerMultiTaskGP(train_x, train_y)
17 posterior = gp.posterior(test_x)
---> 18 posterior.rsample(torch.Size([256, 1]))
File ~/miniconda3/envs/al/lib/python3.10/site-packages/botorch/posteriors/multitask.py:269, in MultitaskGPPosterior.rsample(self, sample_shape)
267 if sample_shape is None:
268 sample_shape = torch.Size([1])
--> 269 return self.rsample_from_base_samples(
270 sample_shape=sample_shape, base_samples=None
271 )
File ~/miniconda3/envs/al/lib/python3.10/site-packages/botorch/posteriors/multitask.py:229, in MultitaskGPPosterior.rsample_from_base_samples(self, sample_shape, base_samples, train_diff)
225 obs_minus_samples = (
226 train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
227 )
228 train_covar_plus_noise = self.train_train_covar + self.train_noise
--> 229 obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))
231 # and multiply the test-observed matrix against the result of the solve
232 updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/operators/_linear_operator.py:2334, in LinearOperator.solve(self, right_tensor, left_tensor)
2332 func = Solve
2333 if left_tensor is None:
-> 2334 return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
2335 else:
2336 return func.apply(
2337 self.representation_tree(),
2338 True,
(...)
2341 *self.representation(),
2342 )
File ~/miniconda3/envs/al/lib/python3.10/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')
File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/functions/_solve.py:53, in Solve.forward(ctx, representation_tree, has_left, *args)
51 res = left_tensor @ res
52 else:
---> 53 solves = _solve(linear_op, right_tensor)
54 res = solves
56 if ctx.is_vector:
File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/functions/_solve.py:17, in _solve(linear_op, rhs)
15 return linear_op.solve(rhs)
16 if settings.fast_computations.solves.off() or linear_op.size(-1) <= settings.max_cholesky_size.value():
---> 17 return linear_op.cholesky()._cholesky_solve(rhs)
18 else:
19 with torch.no_grad():
File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/operators/triangular_linear_operator.py:78, in TriangularLinearOperator._cholesky_solve(self, rhs, upper)
71 def _cholesky_solve(
72 self: Float[LinearOperator, "*batch N N"],
73 rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]],
74 upper: Optional[bool] = False,
75 ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]:
76 # use custom method if implemented
77 try:
---> 78 res = self._tensor._cholesky_solve(rhs=rhs, upper=upper)
79 except NotImplementedError:
80 if upper:
81 # res = (U.T @ U)^-1 @ v = U^-1 @ U^-T @ v
File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/operators/dense_linear_operator.py:38, in DenseLinearOperator._cholesky_solve(self, rhs, upper)
33 def _cholesky_solve(
34 self: Float[LinearOperator, "*batch N N"],
35 rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]],
36 upper: Optional[bool] = False,
37 ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]:
---> 38 return torch.cholesky_solve(rhs, self.to_dense(), upper=upper)
OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB (GPU 0; 79.19 GiB total capacity; 4.73 GiB already allocated; 71.26 GiB free; 5.26 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
The final line requires the allocation of 128GB of GPU memory because of the call to torch.cholesky_solve with B shaped (256, 1, 8192, 1) and L shaped (8192, 8192): broadcasting L across the 256-element batch materializes 256 × 8192 × 8192 doubles, which is again exactly 128GB.
The following line appears to be the root cause: unsqueezing the last dimension sets up a batched matrix-vector solve, whereas transposing one of the batch dimensions to the end instead gives a much more efficient matrix-matrix solve. https://github.com/pytorch/botorch/blob/4497a5c5df7183a1f90219c818e58d172b368cf8/botorch/posteriors/multitask.py#L229
For example, the following code requires less than 4GB of GPU memory, and I believe it is equivalent. By moving the first batch dimension to the final position, we at least stand the chance of having a more efficient operation if the first batch dimension is greater than 1. It would probably be even better to find the largest batch dimension and move that one to the end, or even flatten them all.
perm = list(range(1, obs_minus_samples.ndim)) + [0]
inverse_perm = torch.argsort(torch.tensor(perm))
obs_minus_samples_p = obs_minus_samples.permute(*perm)
obs_solve = train_covar_plus_noise.solve(obs_minus_samples_p)
# and multiply the test-observed matrix against the result of the solve
updated_samples = self.test_train_covar.matmul(obs_solve).permute(*inverse_perm)
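As an illustration of the "largest batch dimension" variant mentioned above, here is a hedged sketch (solve_with_moved_batch_dim is a hypothetical helper, not a BoTorch or linear_operator API, and the shapes are scaled down from the issue):

import torch

def solve_with_moved_batch_dim(solve_fn, rhs):
    # `rhs` is a batched vector of shape (*batch, n). Move the largest batch
    # dimension into the column position so that `solve_fn` sees a batched
    # matrix right-hand side of shape (*smaller_batch, n, k), then undo the
    # permutation. Assumes `solve_fn` broadcasts like torch.linalg ops.
    batch_shape = rhs.shape[:-1]
    largest = max(range(len(batch_shape)), key=lambda i: batch_shape[i])
    perm = [i for i in range(rhs.ndim) if i != largest] + [largest]
    inverse_perm = torch.argsort(torch.tensor(perm)).tolist()
    res = solve_fn(rhs.permute(*perm))
    return res.permute(*inverse_perm)

# Shapes analogous to this issue, scaled down: a (256, 1, 512) batch of
# vectors becomes a single (1, 512, 256) matrix solve.
rhs = torch.randn(256, 1, 512, dtype=torch.double)
A = torch.eye(512, dtype=torch.double) * 2.0
out = solve_with_moved_batch_dim(lambda r: torch.linalg.solve(A, r), rhs)
assert out.shape == rhs.shape
assert torch.allclose(out, rhs / 2.0)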
@Balandat @esantorella: should I submit a PR for this, or is there an obvious reason that it wouldn't work in other use cases?
> should I submit a PR for this, or is there an obvious reason that it wouldn't work in other use cases?
A PR for this would be great; I don't see any obvious reason why this wouldn't work in other cases. Ideally, we could even do this at the level of LinearOperator.solve() so that other cases can also benefit. Even more ideally, this could be done under the hood of PyTorch's torch.cholesky_solve(), but that would be a larger lift, though it's probably worth raising with the PyTorch folks to understand whether there are any plans in that direction. cc @gpleiss, @jacobrgardner
> It would probably be even better to find the largest batch dimension and move that one to the end, or even flatten them all.
Yes, that makes sense. There are probably some nontrivial tradeoffs between flattening all the batch dimensions and keeping them, depending on how exactly the underlying CUDA kernel parallelizes the evaluation in each case.
cc @sdaulton, @jandylin, @SebastianAment re excessive memory usage in Kronecker MTGPs
cc @JonathanWenger
Hi, I had a similar issue and got some useful insights from this post. However, now I am encountering a memory error in matern_kernel.py at this line: x2_ = (x2 - mean).div(self.lengthscale)
Here are my x1, mean, and x2 shapes at this point: x1.shape = torch.Size([32, 50, 16, 1, 100]), mean.shape = torch.Size([32, 50, 16, 1, 100]), x2.shape = torch.Size([32, 50, 16, 1611, 100]). I am using qKnowledgeGradient in optimize_acqf with raw_samples=50, q=1, num_samples=256 (because I am using SaasFullyBayesianSingleTaskGP), and thinning=16.
The error I get is:
line 101, in forward
    x2_ = (x2 - mean).div(self.lengthscale)
RuntimeError: [enforce fail at ..\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 32993280000 bytes.
Is there any way to fix this? Thanks!