
Constructing potential energy functions for many systems on same device

maxentile opened this issue on Mar 29, 2022 • 5 comments

Context

We would like to define a Jax function referencing the potential energy functions of many systems at once (e.g. to define a differentiable loss function over a batch of molecules).

Problem

The way the Jax wrappers are currently implemented requires constructing separate unbound impls for each system, which means (1) some start-up time is needed to instantiate all the impls, and (2) the number of systems is limited by the available GPU memory.

Reproduction

The following loop, which constructs a potential energy function for each molecule in FreeSolv, takes a couple of seconds per iteration and then crashes when GPU memory runs out.

import timemachine
print(timemachine.__version__)
# toggle-energy-minimization branch @ march 29, 2022
# https://github.com/proteneer/timemachine/tree/a2037e14ccefcdad2ac7465a139412893db27cf8
# (so that loop over mols doesn't have to call minimize_host_4d just to construct the potentials)

from timemachine.datasets import fetch_freesolv
from timemachine.md import enhanced
from timemachine.fe.functional import construct_differentiable_interface_fast
from timemachine.ff import Forcefield

ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")
mols = fetch_freesolv()


def prepare_energy_fxn(mol):
    ubps, params, masses, coords, box = enhanced.get_solvent_phase_system(mol, ff, minimize_energy=False)
    U_fxn = construct_differentiable_interface_fast(ubps, params)
    
    return U_fxn, params


# crashes after a few dozen iterations / several minutes
energy_fxns = []
for mol in mols:
    energy_fxns.append(prepare_energy_fxn(mol))
    print(len(energy_fxns))

# ...
# def loss_fxn(ff_params):
#    # ...
#    # something that requires U_fxn(...) for (U_fxn, _) in energy_fxns
#    # ...
#    return loss
# ...
# _ = grad(loss_fxn)(ff_params)
# ...
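
For concreteness, one possible shape for that loss, as a minimal sketch: it assumes the loop above is extended so that each energy_fxns entry also stores the coords and box returned by get_solvent_phase_system, i.e. (U_fxn, params, coords, box), and that U_fxn is callable as U_fxn(coords, params, box, lam); the mapping from ff_params to per-system params is elided, as in the comment above.

import jax

lam = 0.0

def loss_fxn(params_batch):
    # params_batch: one list of parameter arrays per system, derived
    # (differentiably) from ff_params upstream -- that step is elided here
    energies = [
        U_fxn(coords, params, box, lam)
        for (U_fxn, _, coords, box), params in zip(energy_fxns, params_batch)
    ]
    # placeholder loss: a real one would compare these energies (or estimates
    # derived from them) to reference data
    return sum(energies)

params_batch = [params for (_, params, _, _) in energy_fxns]
# _ = jax.grad(loss_fxn)(params_batch)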

(And a slightly loggier version of this loop that also queries GPU memory each iteration)

import subprocess as sp
import os

def get_gpu_memory(device=0):
    """adapted from https://stackoverflow.com/a/59571639"""
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
    memory_free_values = [int(x.split()[0]) for x in memory_free_info]
    return memory_free_values[device]

import numpy as np

import timemachine
print(timemachine.__version__)
# toggle-energy-minimization branch @ march 29, 2022
# https://github.com/proteneer/timemachine/tree/a2037e14ccefcdad2ac7465a139412893db27cf8
# (so that loop over mols doesn't have to call minimize_host_4d just to construct the potentials)

from tqdm import tqdm

import time
from timemachine.datasets import fetch_freesolv
from timemachine.md import enhanced
from timemachine.fe.functional import construct_differentiable_interface_fast
from timemachine.ff import Forcefield
from timemachine.lib.potentials import NonbondedInteractionGroup

ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")
mols = fetch_freesolv()


def nb_ig_from_nb(nonbonded_ubp):
    lambda_plane_idxs = nonbonded_ubp.get_lambda_plane_idxs()
    lambda_offset_idxs = nonbonded_ubp.get_lambda_offset_idxs()
    beta = nonbonded_ubp.get_beta()
    cutoff = nonbonded_ubp.get_cutoff()
    
    ligand_idxs = np.array(np.where(lambda_offset_idxs != 0)[0], dtype=np.int32)

    # switched from Nonbonded to NonbondedInteractionGroup in hopes of reducing memory consumption
    nb_ig = NonbondedInteractionGroup(ligand_idxs, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
    
    return nb_ig


def prepare_energy_fxn(mol, use_interaction_group=True):
    ubps, params, masses, coords, box = enhanced.get_solvent_phase_system(mol, ff, minimize_energy=False)
    n_atoms = len(coords)
    
    if use_interaction_group:
        ubps_prime = ubps[:-1] + [nb_ig_from_nb(ubps[-1])]
        U_fxn = construct_differentiable_interface_fast(ubps_prime, params)
    else:
        U_fxn = construct_differentiable_interface_fast(ubps, params)
    
    return (U_fxn, params, n_atoms)


# hmm, still crashes after several minutes
energy_fxns = []
n_atoms_traj = [0]

device_idx = 1

free_memory_traj = [get_gpu_memory(device_idx)]
use_interaction_group = False

for mol in mols:
    U_fxn, params, n_atoms = prepare_energy_fxn(mol, use_interaction_group=use_interaction_group)
    energy_fxns.append((U_fxn, params))
    n_atoms_traj.append(n_atoms)
    
    # wait a bit before querying nvidia-smi
    time.sleep(0.5)
    free_memory_traj.append(get_gpu_memory(device_idx))
    
    np.savez(f'memory_log_ig={use_interaction_group}', free_memory_traj=free_memory_traj, n_atoms_traj=n_atoms_traj)
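
For reference, the per-system memory estimate quoted in the notes below can be read back off the saved log with something like this (a minimal sketch; nvidia-smi reports memory.free in MiB):

import numpy as np

log = np.load("memory_log_ig=False.npz")
free_memory_traj = log["free_memory_traj"]

# drop in free GPU memory after constructing each additional system, in MiB
per_system_mib = -np.diff(free_memory_traj)
print(per_system_mib.mean())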

Notes

  • In this loop the systems contain about 2500 atoms each.
  • @proteneer suggested checking whether the memory consumption per system is reduced by using the NonbondedInteractionGroup potential instead of the default Nonbonded potential -- in both cases the memory consumption appears to be ~64 MB per system. [attached plot: free GPU memory vs. number of systems constructed]
  • The crash occurs when nvidia-smi reports ~1 GB of memory remaining.

Possible solutions

  • Refactor the Jax wrappers so that multiple systems can use the same underlying unbound_impl? (May address both the startup cost and the memory limit)
  • Reduce the memory consumption of each (neighborlist?) impl? (May be more invasive, may address only the memory limit)
  • Use separate GPUs for separate systems? (A rough sketch of this option follows this list.)
  • Use the Jax reference implementation in these cases?
  • ...
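
A rough sketch of the separate-GPUs option mentioned above, assuming a hypothetical worker script prepare_energy_fxns_worker.py that processes every n_shards-th molecule (the script name and sharding scheme are illustrative, not part of timemachine):

import os
import subprocess

n_shards = 2  # assumption: number of GPUs available on the machine

procs = []
for shard in range(n_shards):
    # restrict each worker process to a single GPU
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(shard))
    procs.append(subprocess.Popen(
        ["python", "prepare_energy_fxns_worker.py",  # hypothetical worker script
         "--shard", str(shard), "--n-shards", str(n_shards)],
        env=env,
    ))

for p in procs:
    p.wait()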

maxentile, Mar 29, 2022

Refactor the Jax wrappers so that multiple systems can use the same underlying unbound_impl?

Actually, this route would probably require a deeper change than just the Jax wrappers: setters like potential.set_idxs(...) are available on the potentials, but I think idxs etc. are fixed for the lifetime of the impls created via impl = potential.unbound_impl(precision)...

Reduce the memory consumption of each (neighborlist?) impl?

~~Not sure how much room there is for reduction here -- memory usage is close to what you would expect for 2500 * 3 float64s...~~ Oof, off by a factor of a thousand: should expect ~60 kilobytes, not ~60 megabytes...
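
The arithmetic behind that correction, spelled out:

# 2500 atoms x 3 components x 8 bytes (float64) per coordinate
print(2500 * 3 * 8)  # 60,000 bytes, i.e. ~60 kB -- vs. the ~64 MB observed per system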

maxentile, Mar 29, 2022

From looking into this with @proteneer and @jkausrelay: the source of the ~64 MB allocations is https://github.com/proteneer/timemachine/blob/136da6b377fd1cc29388d233f1d80d175eed83d1/timemachine/cpp/src/nonbonded_all_pairs.cu#L120-L137

Short-term solution: replace 256 with NUM_BINS = 96 (or 128), which should have no impact on correctness, only a modest impact on efficiency, and will reduce the allocation size from ~64 MB to ~3 MB (or ~8 MB) per system.
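
Back-of-the-envelope check of those sizes, assuming the allocation in question holds one 4-byte entry per cell of a NUM_BINS^3 grid (consistent with the numbers above; the exact layout is in the linked code):

for num_bins in (256, 96, 128):
    n_bytes = num_bins ** 3 * 4  # assumption: 4 bytes per grid cell
    print(num_bins, f"{n_bytes / 2**20:.1f} MiB")
# 256 -> 64.0 MiB, 96 -> 3.4 MiB, 128 -> 8.0 MiB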

Longer-term solution: extract this Hilbert index into a separate object that can be shared by multiple nonbonded impls, rather than duplicating it for each nonbonded impl.

Side quest: possibly re-run a benchmarking script over a grid of NUM_BINS settings, to measure the performance impact of this parameter.

maxentile, Mar 30, 2022

(Side note: this index is also duplicated between nonbonded_all_pairs.cu and nonbonded_interaction_group.cu, so factoring it out of both files could be related to https://github.com/proteneer/timemachine/issues/639.)

maxentile, Mar 30, 2022

https://github.com/proteneer/timemachine/blob/master/timemachine/cpp/src/kernels/k_nonbonded.cu#L3 would also need to be updated; 256 is hardcoded there as well.

jkausrelay, Apr 4, 2022

good call!

proteneer, Apr 4, 2022

Was resolved by #692.

(And the need to instantiate batches of GPU potentials for reweighting is avoided anyway when a linear basis function trick (#685, #931) is applicable.)

maxentile, Mar 7, 2023