pyscf-ipu
DFT iteration takes 5s but profile only ~100ms
The structure optimization in nanoDFT.py takes 5s wallclock each iteration (the call `jitted_nanoDFT(*tensors)`), but the PopVision profile shows only ~100ms.
Goal. Figure out what accounts for the remaining ~4.9s and remove it.
Things I'd try:
- put a print inside _nanoDFT to see if it traces multiple times.
- move `jax.jit(partial(_nanoDFT, ...))` outside and pass the jitted function as input to `nanoDFT`
- run with `XLA_VISIBLE_DEVICES=1 python ...` (hunch: time is spent loading executables to many chips)
The problem seems to be related to the use of `partial` (docs here).
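To make the failure mode concrete, here is a minimal sketch (my own, not from the repo) of why rebuilding the `partial` on every call defeats jit's cache: jit keys its cache on the function object, and every `partial(...)` call creates a fresh object:

```python
import jax
from functools import partial

def _f(x, c):
    print("tracing!")  # runs once per (re)compilation
    return x * c

# A fresh partial per call looks like a brand-new function to jit,
# so every call re-traces (and on IPU, re-loads the executable):
for _ in range(2):
    jax.jit(partial(_f, c=2.0))(1.0)   # prints "tracing!" both times

# Reusing one jitted object hits the cache after the first call:
jitted = jax.jit(partial(_f, c=2.0))
for _ in range(2):
    jitted(1.0)                        # prints "tracing!" once
```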
Caching the jitted `nanoDFT` function is faster (ugly code to prove it):

```python
cached_jitted_nanoDFT = None

def nanoDFT(mol, opts):
    global cached_jitted_nanoDFT
    # Init DFT tensors on CPU using PySCF.
    tensors, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords = init_dft_tensors_cpu(mol, opts)
    # Run DFT algorithm (can be hardware accelerated).
    if cached_jitted_nanoDFT is None:
        cached_jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
    jitted_nanoDFT = cached_jitted_nanoDFT
    vals = jitted_nanoDFT(*tensors)
    logged_matrices, H_core, logged_energies = [np.asarray(a).astype(np.float64) for a in vals]  # Ensure CPU

    # It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices.
    logged_E_xc = logged_energies[:, 4].copy()
    density_matrices, Js, Ks, H = [logged_matrices[:, i] for i in range(4)]
    energies, hlgaps = np.zeros((opts.its, 6)), np.zeros(opts.its)
    for i in range(opts.its):
        energies[i] = energy(density_matrices[i], H_core, Js[i], Ks[i], logged_E_xc[i], E_nuc, np)
        hlgaps[i] = hlgap(L_inv, H[i], n_electrons_half, np)
    energies, logged_energies, hlgaps = [a * HARTREE_TO_EV for a in [energies, logged_energies, hlgaps]]
    mo_energy, mo_coeff = np.linalg.eigh(L_inv @ H[-1] @ L_inv.T)
    mo_coeff = L_inv.T @ mo_coeff
    return energies, (logged_energies, hlgaps, mo_energy, mo_coeff, grid_coords, grid_weights)
```
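One caveat with the single global cache above: if `nanoDFT` is later called with a different molecule, it would silently reuse an executable compiled for the old problem size. A hedged sketch of a keyed variant (keying on `mol.nao_nr()` is an assumption, following the hashing idea below):

```python
# Sketch: one jitted function per (problem size, opts) instead of a
# single global slot. mol.nao_nr() as "what forces a recompile" is a
# guess; a domain expert may want more attributes in the key.
_jit_cache = {}

def get_jitted_nanoDFT(mol, opts):
    key = (mol.nao_nr(), opts)  # requires opts to be hashable (see below)
    if key not in _jit_cache:
        _jit_cache[key] = jax.jit(partial(_nanoDFT, opts=opts, mol=mol),
                                  backend=opts.backend)
    return _jit_cache[key]
```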
Tried caching `mol` instead, but it doesn't seem to make a difference; haven't tried caching all arguments though.
Hypothesis: `nanoDFT(mol, opts)` needs access to `mol` while JAX traces `nanoDFT`. JAX allows this via `static_argnums`, which uses hashing to check whether recompilation is needed. Because `mol`/`opts` don't support hashing, we instead use `partial` to create a new function (which causes the problem of this thread). One potential fix may be to add a custom hash to `mol`/`opts` which only changes when we need to recompile (i.e. when `mol.nao_nr()` changes, so `mol.__hash__ = mol.nao_nr`). @balancap Is this too ugly?
Based on what I could find, one way is to subclass it, for example:

```python
import pyscf
from jax import tree_util


class HashableMole(pyscf.gto.mole.Mole):
    def __init__(self):
        pyscf.gto.mole.Mole.__init__(self)
        print('Hello HashableMole')

    def _tree_flatten(self):
        children = ()   # arrays / dynamic values
        aux_data = {}   # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)

    def __hash__(self):
        print('Hello __hash__!')
        return hash(self.nao_nr())  # based on previous suggestion

    def __eq__(self, other):
        print('Hello __eq__')
        return isinstance(other, HashableMole) and self.nao_nr() == other.nao_nr()  # based on previous suggestion


tree_util.register_pytree_node(HashableMole,
                               HashableMole._tree_flatten,
                               HashableMole._tree_unflatten)
```
and initialise `mol = HashableMole()` in `build_mol`, and then use:

```python
jitted_nanoDFT = jax.jit(_nanoDFT, backend=opts.backend, static_argnames=['mol', 'opts'])
vals = jitted_nanoDFT(*tensors, mol=mol, opts=opts)
```
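For reference (my reading of the JAX docs, not repo-specific behaviour): with `static_argnames`, jit keys its compilation cache on the hash of the static arguments, so:

```python
vals = jitted_nanoDFT(*tensors, mol=mol, opts=opts)   # first call: traces + compiles
vals = jitted_nanoDFT(*tensors, mol=mol, opts=opts)   # same hashes: cache hit
# A molecule whose hash differs (here: a different nao_nr()) triggers a
# recompile -- which is exactly when we actually need one.
```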
Doing this, it looks like it also complains about `opts` not being hashable because it's not immutable (e.g. `opts.basis = "6-31G"` is possible), so wrapping it in a `namedtuple` does the trick.
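A minimal sketch of that wrapping (the field names are made up; the real `opts` has whatever fields nanoDFT.py defines):

```python
from collections import namedtuple

# Hypothetical fields -- only its/backend/basis appear in this thread.
Opts = namedtuple("Opts", ["its", "backend", "basis"])
opts = Opts(its=20, backend="ipu", basis="6-31G")

hash(opts)               # works: namedtuples hash by value
# opts.basis = "sto-3g"  # would raise AttributeError: namedtuples are immutable
```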
Hope this helps?
Neat! Let's revise and make it into a PR :)
Q1. What happens if we remove all the jax tree stuff? (I might just be misunderstanding, but don't see why it's needed in this case)
Q2. Could we be a bit cheeky and do `mol.__hash__ = self.nao_nr; mol.__eq__ = lambda self, other: ...`?
Perhaps the wrong issue to attach this to, but seeing as you're in there anyway: I have just been looking at what's going on, to see what/how this could demonstrate that IPUs are faster, and I see the following (all this FYI, so please ignore if you know this already). On both IPU and CPU, the `jax.jit` call itself is instantaneous, i.e. at least one call of `jitted_nanoDFT` itself is required to get anything to compile/happen. Then on IPU we see that:
- the very first time (i.e. when a compilation is needed), it takes 149 seconds
- when loading from the cache, it takes 7.5 seconds
- when subsequently used in the same program, it takes 16ms

For comparison, on CPU the first pass takes 3.5 seconds and subsequent passes take 933ms. Timings were measured like this:
```python
# Run DFT algorithm (can be hardware accelerated).
start = time.perf_counter_ns()
jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
jit_complete = time.perf_counter_ns()
vals = jitted_nanoDFT(*tensors)
computation_complete = time.perf_counter_ns()
vals = jitted_nanoDFT(*tensors)
computation_complete2 = time.perf_counter_ns()
logged_matrices, H_core, logged_energies = [np.asarray(a).astype(np.float64) for a in vals]  # Ensure CPU

## No need to print this as it is instantaneous:
## print(f"JIT duration was {(jit_complete - start) // 1000000}ms.")
print(f"1st pass duration was {(computation_complete - jit_complete) // 1000000}ms.")
print(f"2nd pass duration was {(computation_complete2 - computation_complete) // 1000000}ms.")
```
To "boot it up" , it needs to be run once... so it isn't clear how we could (or should?) use this sample to demonstrate that an IPU is faster than a CPU, as we'd need to be doing more than one iteration in the same program to show that.
Thanks for taking the time to dive into the code Anthony! :)
My take: we should change the default to an interesting use-case that runs ~10^3 iterations (e.g. like this video: each frame is one iteration, so render a movie with 10^3 frames).
Fun note: With 16ms we can do ~60 FPS on IPU compared to ~1 FPS on CPU. We could make a "side-by-side" comparison of frame-rate using different CPU/IPU backends.
@AlexanderMath
Q1. What happens if we remove all the jax tree stuff? (I might just be misunderstanding, but don't see why it's needed in this case)
In fact, nothing happens: it works without those methods and without registering it. It was a misunderstanding on my side; I thought registering it as a PyTree was necessary (docs link) for it to be handled correctly by jit, but it is not.
Thanks, this simplifies the code!
Nice! Rule of thumb in JAX: registering a custom class with the PyTree mechanism should really be exceptional. 99% of the time, we should be fine passing a tuple, namedtuple, dataclass (using the chex library), dict, ...
@AlexanderMath
Q2. Could we be a bit cheeky and do `mol.__hash__ = self.nao_nr; mol.__eq__ = lambda self, other: ...`?
This seems to work:

```python
pyscf.gto.mole.Mole.__hash__ = lambda self: hash(self.nao_nr())
pyscf.gto.mole.Mole.__eq__ = lambda self, other: self.nao_nr() == other.nao_nr()
```
EDIT (I thought we needed setattr, but the above works, so leaving both here): we could use setattr too, for example:

```python
setattr(pyscf.gto.mole.Mole, "__hash__", lambda self: hash(self.nao_nr()))
setattr(pyscf.gto.mole.Mole, "__eq__", lambda self, other: self.nao_nr() == other.nao_nr())
```
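One detail worth noting (general Python behaviour, not PySCF-specific): the class has to be patched, as in both snippets above, because Python looks up special methods on the type; the literal instance assignment from Q2 would be silently ignored. A tiny stand-alone sketch:

```python
class Mole:  # stand-in for pyscf.gto.mole.Mole
    def nao_nr(self):
        return 7

mol = Mole()
mol.__hash__ = lambda: 42                         # instance attribute: ignored
print(hash(mol) == 42)                            # False -- dunders come from the type

Mole.__hash__ = lambda self: hash(self.nao_nr())  # class-level patch: takes effect
print(hash(mol) == hash(7))                       # True
```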
If you are happy with this strategy, it would be helpful to know:
- whether we should keep these globally in the file, or maybe add them in `build_mol` to make it a bit clearer what we're doing?
- whether it's fine to hash just `nao_nr`, or whether we want a list of attributes to include, which potentially makes more sense from a domain expert's point of view?
Other options considered:
- refactoring `mol` out of the code, but this seems a bit drastic as it's not just a few occurrences
- using the derived class, but it is probably too complicated for what we need
Thanks!