
[Question] Tool suggestions for visualizing equinox module dynamics

Open smorad opened this issue 9 months ago • 2 comments

Are there any tools you would suggest for interpreting the dynamics of deep/complex equinox modules? Perhaps something that tracks the mean and standard deviation of the weights, or the norm of the activations? Maybe how the gradient is distributed between parameters? I suppose one could do something with jax.tree_util.tree_leaves but I wanted to see if there were any existing tools you could suggest.

smorad, Nov 18 '23 14:11
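One way to do what the question describes with `jax.tree_util` directly is sketched below. It is not an existing tool. It computes the per-leaf mean, standard deviation and L2 norm of the weights, plus the share of the total squared gradient norm that each leaf carries. The sketch assumes the model is any pytree whose parameters are `jax.Array` leaves. An Equinox module qualifies, since modules are pytrees, and gradients from `eqx.filter_grad` leave `None` in non-array positions, which `jax.tree_util` skips. The helper names `leaf_stats` and `grad_share` are hypothetical. Activation norms are not stored in the model, so they would have to be returned from the forward pass, and this sketch does not cover them.

```python
import jax
import jax.numpy as jnp


def leaf_stats(tree):
    """Per-leaf mean, std and L2 norm for every array leaf of a pytree (hypothetical helper)."""
    stats = {}
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
        # Skip non-array leaves, e.g. activation functions stored on a module.
        if isinstance(leaf, jax.Array):
            stats[jax.tree_util.keystr(path)] = {
                "mean": float(jnp.mean(leaf)),
                "std": float(jnp.std(leaf)),
                "norm": float(jnp.linalg.norm(leaf.ravel())),
            }
    return stats


def grad_share(grads):
    """Fraction of the total squared gradient norm carried by each leaf (hypothetical helper)."""
    sq = {
        jax.tree_util.keystr(path): float(jnp.sum(leaf ** 2))
        for path, leaf in jax.tree_util.tree_leaves_with_path(grads)
        if isinstance(leaf, jax.Array)
    }
    total = sum(sq.values())
    # If every gradient is exactly zero, return the raw (all-zero) squared norms instead of dividing by 0.
    return {k: v / total for k, v in sq.items()} if total > 0 else sq
```

Keying the results by `jax.tree_util.keystr(path)` gives readable names such as `['w']` for dicts or `.layers[0].weight` for Equinox modules, which makes it easy to spot which layer dominates the gradient.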

I don't know of such a package, I'm afraid. Maybe an opportunity to write one? :) If it ends up being general enough, then we could think about putting it in eqx.nn, or keeping it as a standalone project and pointing users towards it.

patrick-kidger, Nov 18 '23 19:11

@smorad fwiw, I did this kind of thing in some util I made once.

```python
import jax
import numpy as np


def get_summary_info(model):
    """An alternative repr useful for initial debugging: one row per pytree leaf."""
    import pandas as pd  # imported lazily so pandas is only needed when this is called

    def get_info(v):
        info = dict()
        info['type'] = type(v).__name__
        if isinstance(v, (jax.Array, np.ndarray, float)):
            info['dtype'] = v.dtype.name if hasattr(v, 'dtype') else None  # Python floats have no dtype
            info['shape'] = np.shape(v)
            info['size'] = np.size(v)
            info['nancount'] = np.isnan(v).sum()
            info['zerocount'] = np.size(v) - np.count_nonzero(v)
            if np.size(v) > 0:  # min/max raise on empty arrays
                info['min'] = np.min(v).item()
                info['max'] = np.max(v).item()
        return info

    # keystr turns each leaf's path into a readable string such as ".layers[0].weight"
    d_ = {jax.tree_util.keystr(k): get_info(v) for k, v in jax.tree_util.tree_leaves_with_path(model)}
    return pd.DataFrame(d_).T  # transpose so rows are leaves and columns are statistics
```

cottrell, Nov 20 '23 09:11
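A per-leaf summary like the one above gives a single snapshot. To see the dynamics the question asks about, one could record per-leaf statistics at regular intervals during training. Below is a minimal sketch that assumes a plain gradient-descent loop over a pytree of `jax.Array` parameters. `leaf_norms` and `train_and_log` are hypothetical names. For an Equinox model you would typically swap in `eqx.filter_value_and_grad` and `eqx.apply_updates` for the plain `jax` calls used here.

```python
import jax
import jax.numpy as jnp


def leaf_norms(tree):
    # L2 norm of every array leaf, keyed by its pytree path (hypothetical helper).
    return {
        jax.tree_util.keystr(path): float(jnp.linalg.norm(leaf.ravel()))
        for path, leaf in jax.tree_util.tree_leaves_with_path(tree)
        if isinstance(leaf, jax.Array)
    }


def train_and_log(params, loss_fn, lr=0.1, steps=20, log_every=5):
    """Plain gradient descent that records loss, weight norms and gradient norms every `log_every` steps."""
    grad_fn = jax.jit(jax.value_and_grad(loss_fn))
    history = []
    for step in range(steps):
        loss, grads = grad_fn(params)
        if step % log_every == 0:
            history.append({
                "step": step,
                "loss": float(loss),
                "weights": leaf_norms(params),  # norms before this step's update
                "grads": leaf_norms(grads),
            })
        # SGD update applied leaf by leaf
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, history
```

Each entry of `history` is a nested dict, so `pandas.json_normalize(history)` would flatten it into one column per leaf and statistic, ready for plotting over `step`.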