Batch convolution
This PR addresses the issues in #345. In particular, it introduces the possibility of batching the convolution over input channels and filters (usually basis functions).
Background
NeMoS applies convolutions efficiently in a vectorized way. By default, it tries to vectorize over all dimensions:
- The input channels (for example, the number of neurons in a spike count array).
- The number of filters (usually the number of basis functions).
This behavior is equivalent to the original implementation. However, this vectorization allocates a large amount of memory, proportional to (num channels x num filters). For example, convolving the counts of 200 neurons with 8 basis functions requires vectorizing over 200 x 8 = 1600 dimensions, which can cause the GPU to run out of memory even when the input and output of the convolution easily fit in memory.
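Rough back-of-the-envelope arithmetic for the numbers above (assuming float32, a 500,000-sample recording as in the example script, and assuming the fully vectorized path materializes a window-sized intermediate per (channel, filter) pair; the exact intermediates depend on how XLA compiles the convolution):

```python
n_samples, n_channels, n_basis, window = 500_000, 200, 8, 20

# Final output: one value per (sample, channel, basis) triple.
out_bytes = n_samples * n_channels * n_basis * 4  # float32
print(f"output: {out_bytes / 1024**2:.2f} MiB")  # ~3 GiB, fits on a GPU

# A fully vectorized convolution that also materializes rolling windows
# for every (channel, basis) pair would need roughly `window` times more:
intermediate_bytes = out_bytes * window
print(f"windowed intermediate: {intermediate_bytes / 1024**3:.1f} GiB")  # does not fit
```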
What's New
In this PR, I introduced three optional parameters: batch_size_samples, batch_size_channels and batch_size_basis. These are integer values setting the batch sizes over which the vectorized convolution is applied. In particular, when specified, the convolution will vectorize over (batch_size_channels x batch_size_basis) dimensions and run sequentially over chunks of the array of length batch_size_samples. The most memory-conservative approach is to set all of these parameters to 1, which applies one convolution at a time.
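To make the chunking idea concrete, here is a minimal NumPy sketch (illustrative only: `batched_conv` is a hypothetical helper, not the NeMoS implementation, and it omits the time-batching dimension handled by batch_size_samples):

```python
import numpy as np

def batched_conv(x, filters, batch_channels=1, batch_basis=1):
    """Causal convolution of x (T, C) with filters (W, B), processed in
    chunks of channels and basis functions to bound peak memory.

    Only a (batch_channels x batch_basis) slab of (channel, filter)
    pairs is in flight at a time.
    """
    T, C = x.shape
    W, B = filters.shape
    out = np.empty((T, C, B))
    for c0 in range(0, C, batch_channels):
        for b0 in range(0, B, batch_basis):
            # Process one slab; a real implementation would vectorize
            # the pairs inside the slab instead of looping over them.
            for c in range(c0, min(c0 + batch_channels, C)):
                for b in range(b0, min(b0 + batch_basis, B)):
                    out[:, c, b] = np.convolve(
                        x[:, c], filters[:, b], mode="full"
                    )[:T]
    return out
```

Setting batch_channels = batch_basis = 1 mirrors the most memory-conservative configuration described above: one convolution at a time.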
Example
Here is a script demonstrating the feature.
```python
import os

# JAX memory management flags
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Optional: makes JAX deallocate GPU memory aggressively
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import time
import subprocess

import jax
import numpy as np

import nemos as nmo


def print_gpu_mem(tag=""):
    output = subprocess.check_output([
        "nvidia-smi", "--query-gpu=memory.used,memory.total",
        "--format=csv,nounits,noheader",
    ])
    used, total = map(int, output.decode().split(","))
    print(f"[{tag}] GPU memory usage: {used} MiB / {total} MiB")


print("JAX version:", jax.__version__)
print("JAX device:", jax.devices()[0])
print("JAX device type:", jax.devices()[0].device_kind)
print("sys.platform:", os.sys.platform)
print("Nemos version:", nmo.__version__)
print("NumPy version:", np.__version__)

window_size = 20
basis = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs=8,
    window_size=window_size,
    label="count_history",
    conv_kwargs={
        "batch_size_samples": 125_000,
        "batch_size_channels": 1,
        "batch_size_basis": 1,
    },
)
spike_matrix = np.random.randint(0, 2, size=(500_000, 200))

print_gpu_mem("Before compute_features")
start = time.time()
conv_spk = basis.compute_features(spike_matrix)
elapsed = time.time() - start
print_gpu_mem("After compute_features")

print(f"conv_spk shape: {conv_spk.shape}")
print(f"conv_spk nbytes: {conv_spk.nbytes / 1024**2:.2f} MiB")
print(f"Computation took {elapsed:.2f} seconds")
```
On my machine this outputs:

```
(venv) [ebalzani@ccnlin052 generalized-linear-models]$ python _scripts/gpu_mem_allocation.py
JAX version: 0.5.0
JAX device: cuda:0
JAX device type: Quadro RTX 6000
sys.platform: linux
Nemos version: 0.2.3.dev190
NumPy version: 2.1.3
[Before compute_features] GPU memory usage: 1463 MiB / 24576 MiB
[After compute_features] GPU memory usage: 4527 MiB / 24576 MiB
conv_spk shape: (500000, 1600)
conv_spk nbytes: 3051.76 MiB
Computation took 2.60 seconds
```
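As a sanity check, the reported size matches a float32 array of shape (500000, 1600), i.e. 500,000 samples by 200 channels x 8 basis functions:

```python
n_samples = 500_000
n_features = 200 * 8  # channels x basis functions = 1600 columns
nbytes_mib = n_samples * n_features * 4 / 1024**2  # 4 bytes per float32
print(f"{nbytes_mib:.2f} MiB")  # prints 3051.76 MiB, matching the log
```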
@arturoptophys you can check this branch out and let me know how it goes.
Yes, it works now!
However,

```python
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
```

is still required.
With it set:

```
JAX version: 0.6.0
JAX device: cuda:0
JAX device type: NVIDIA RTX 4000 Ada Generation Laptop GPU
sys.platform: linux
Nemos version: 0.2.3
NumPy version: 1.26.4
[Before compute_features] GPU memory usage: 2285 MiB / 12282 MiB
2025-05-21 22:03:01.496869: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 5.63GiB (6044409844 bytes) by rematerialization; only reduced to 5.98GiB (6419757132 bytes), down from 5.98GiB (6419757132 bytes) originally
[After compute_features] GPU memory usage: 5349 MiB / 12282 MiB
conv_spk shape: (500000, 1600)
conv_spk nbytes: 3051.76 MiB
Computation took 1.79 seconds
```
Without it:

```
2025-05-21 22:05:41.966492: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 5.08GiB (5452891411 bytes) by rematerialization; only reduced to 5.98GiB (6419757132 bytes), down from 5.98GiB (6419757132 bytes) originally
2025-05-21 22:05:42.161530: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 2.60GiB (2792892015 bytes) by rematerialization; only reduced to 2.98GiB (3200000000 bytes), down from 2.98GiB (3200000000 bytes) originally
2025-05-21 22:05:42.215929: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 2.60GiB (2790882583 bytes) by rematerialization; only reduced to 2.98GiB (3200493599 bytes), down from 2.98GiB (3200493599 bytes) originally
2025-05-21 22:05:52.238461: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.98GiB (rounded to 3199993600) requested by op
2025-05-21 22:05:52.238531: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] _________________________________******************************___________________________
E0521 22:05:52.238543 23357 pjrt_stream_executor_client.cc:2839] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3199993600 bytes. [tf-allocator-allocation-error='']
Traceback (most recent call last):
```
Ok, I will make it bulletproof by batching over time as well (so that, in principle, it would still be possible to parallelize over the other axes and have an efficient convolution on the GPU). Good that it already works, but I'll keep working on it.
@arturoptophys try it now, I added the possibility of scan through time as well. I updated my comment to reflect the new parameter names. Scanning over time will be slightly slower (or quite slower for small batch sizes) but safer. Let me know if this work without the flags for deallocating etc.
Codecov Report
Attention: Patch coverage is 96.99248% with 4 lines in your changes missing coverage. Please review.
Project coverage is 70.36%. Comparing base (f839f5c) to head (f96f332). Report is 52 commits behind head on development.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/nemos/validation.py | 89.47% | 4 Missing :warning: |
Additional details and impacted files
```
@@ Coverage Diff @@
## development #348 +/- ##
================================================
- Coverage 96.72% 70.36% -26.37%
================================================
Files 39 90 +51
Lines 3636 8584 +4948
================================================
+ Hits 3517 6040 +2523
- Misses 119 2544 +2425
```
:umbrella: View full report in Codecov by Sentry.
I added tests for both functions and edited the docs; let me know.