
Start facilitating coefficient splitting

Open · BalzaniEdoardo opened this issue 1 year ago · 1 comment

Summary

This PR starts the process of facilitating slicing of the feature axis of either the model coefficients or the design matrix.

In particular it introduces in Basis:

  • label parameter: a name for the variable; write-only. The label of an AdditiveBasis or MultiplicativeBasis combines the labels of its 1D component bases.
  • n_basis_input parameter: write-only; the size of axis 1 of a 2D input to the basis. This arises in conv mode, for example when convolving the spike counts of a neural population.
  • _get_feature_slicing: a recursive method that returns a dictionary of slices that can be applied to the feature axis, keyed by label (or combination of labels).
  • split_feature_axis: the user-facing method that splits an array into its feature components.
  • num_output_features: a read-only property that returns the number of output features.

The _get_feature_slicing method is for internal use, but it makes our life very easy: in split_feature_axis we can use jax.tree_util.tree_map to apply the slicing to any array and automatically get a meaningfully labeled dictionary containing the coefficients.

If the arrays are NumPy arrays, slicing with slice objects is very efficient, since it creates a dict of views of the original array.

It will also facilitate creating a "FeaturePytree" from an input in array or TsdFrame form.
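As a quick illustration of the view behavior mentioned above (plain NumPy, independent of nemos):

```python
import numpy as np

# Slicing a NumPy array along an axis with a slice object returns a view,
# not a copy, so a dict of slices costs no data movement.
X = np.zeros((3, 4))
part = X[:, slice(0, 2)]        # columns 0-1, no data copied
assert part.base is X           # `part` shares memory with X

# Writing through the view is visible in the original array.
part[0, 0] = -1.0
assert X[0, 0] == -1.0
```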

Example of _get_feature_slicing:

>>> import numpy as np
>>> import nemos as nmo
>>> import jax

>>> bas1 = nmo.basis.RaisedCosineBasisLinear(3, mode="conv", n_basis_input=2, window_size=5, label="position")

>>> bas2 = nmo.basis.MSplineBasis(4, mode="conv", n_basis_input=3, window_size=5,  label="velocity")

>>> bas3 = bas1 + bas2 + bas1 * bas2

>>> # slice each individual input, default behavior.
>>> slice_dict = bas3._get_feature_slicing()[0]  
>>> slice_dict
{'position': {'0': slice(0, 3, None), '1': slice(3, 6, None)},
 'velocity': {'0': slice(6, 10, None),
  '1': slice(10, 14, None),
  '2': slice(14, 18, None)},
 '(position * velocity)': slice(18, 90, None)}

>>> # slice each additive component instead.
>>> bas3._get_feature_slicing(split_by_input=False)[0] 
{'position': slice(0, 6, None),
 'velocity': slice(6, 18, None),
 '(position * velocity)': slice(18, 90, None)}

>>> # splitting a design matrix becomes trivial
>>> x1 = np.random.normal(size=(10, 2))
>>> x2 = np.random.normal(size=(10, 3))
>>> X = bas3.compute_features(x1, x2, x1, x2)
>>> splits = jax.tree_util.tree_map(lambda sl: X[:, sl], slice_dict)
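The tree_map pattern can be checked with plain NumPy and jax.tree_util, using a hand-written slice dict in place of the one returned by _get_feature_slicing (the sizes below just mirror the example above):

```python
import numpy as np
import jax

# Hand-written stand-in for the dict produced by _get_feature_slicing.
slice_dict = {
    "position": slice(0, 6),
    "velocity": slice(6, 18),
    "(position * velocity)": slice(18, 90),
}

X = np.random.normal(size=(10, 90))  # design matrix with 90 features

# slice objects are pytree leaves, so tree_map applies the slicing
# leaf-wise and preserves the labeled dict structure.
splits = jax.tree_util.tree_map(lambda sl: X[:, sl], slice_dict)

# The split columns tile the full feature axis.
assert splits["position"].shape == (10, 6)
assert sum(v.shape[1] for v in splits.values()) == X.shape[1]
```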

Example of split_feature_axis:

>>> import numpy as np
>>> from nemos.basis import BSplineBasis
>>> from nemos.glm import GLM

>>> # Define an additive basis
>>> basis = (
...     BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") +
...     BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2")
... )

>>> # split an arbitrarily shaped array
>>> array = np.ones((1, 1, basis.num_output_features, 1))
>>> basis.split_feature_axis(array, axis=2)
{'feature_1': array([[[[1.],
          [1.],
          [1.],
          [1.],
          [1.]]]]),
 'feature_2': array([[[[1.],
          [1.],
          [1.],
          [1.],
          [1.],
          [1.]]]])}

>>> # example of usage in combination with GLM
>>> X = np.random.normal(size=(20, basis.num_output_features))
>>> y = np.random.poisson(size=(20, ))
>>> basis.split_feature_axis(GLM().fit(X, y).coef_, axis=0)
{'feature_1': Array([-0.02247754,  0.49239248, -0.09706223, -0.30416837,  0.04843776],      dtype=float32),
 'feature_2': Array([-0.29889402, -0.0040512 , -0.28740323,  0.5222396 ,  0.55201346,
        -0.13157026], dtype=float32)}

[NOTE] The use of parentheses guarantees that the label fully specifies the order of operation in a composite basis.

BalzaniEdoardo avatar Oct 09 '24 22:10 BalzaniEdoardo

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 97.27%. Comparing base (81b9dee) to head (bf87623). Report is 46 commits behind head on development.

Additional details and impacted files
@@               Coverage Diff               @@
##           development     #247      +/-   ##
===============================================
+ Coverage        96.73%   97.27%   +0.53%     
===============================================
  Files               25       25              
  Lines             2206     2199       -7     
===============================================
+ Hits              2134     2139       +5     
+ Misses              72       60      -12     


codecov-commenter avatar Oct 10 '24 16:10 codecov-commenter

I'm confused by the behavior of split_by_feature:

* Why does it return 3d arrays? If I look at the example in the docstring, it returns arrays of shape (20,1,5) and (20,1,6). What's the middle dimension? This is also the behavior for a single basis or multiplicative basis.

Bases in conv mode allow multi-dimensional inputs (intended for counts, which are a TsdFrame in pynapple). The extra middle dimension indexes these inputs. Example below:

>>> import nemos as nmo
>>> import pynapple as nap
>>> import numpy as np
>>> counts = nap.TsdFrame(t=np.arange(100), d=np.random.poisson(size=(100, 2)))
>>> basis = nmo.basis.BSplineBasis(5, mode="conv", window_size=50)
>>> X = basis.compute_features(counts)
>>> X.shape
(100, 10)
>>> basis.split_by_feature(X, axis=1)["BSplineBasis"].shape
(100, 2, 5)

* This method only really makes sense for the additive basis, right? Otherwise it doesn't do anything but add the middle dimension? In that case, should it exist for the other basis objects? Or am I missing something?

I prefer that every basis have the method, for consistency. I can see a case in which one creates an additive basis in a script but the number of components is variable, from 1 to N; with the method, you can use the same exact code to process a regular basis (1 component) and the rest.

Secondly, reshaping correctly by input is handy. Without the method, how would you split the feature axis correctly in my Python example above: X.reshape(100, 2, 5) or X.reshape(100, 5, 2)? That depends on the internals of the convolution, which one shouldn't need to memorize.
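The ambiguity is easy to see on a toy array (plain NumPy; the 2-input/5-basis sizes mirror the example above, and the input-major column layout is an assumption for illustration):

```python
import numpy as np

# Fake feature axis: 2 inputs x 5 basis functions = 10 columns,
# assumed input-major: [in0_b0..in0_b4, in1_b0..in1_b4].
X = np.arange(10).reshape(1, 10)

a = X.reshape(1, 2, 5)  # groups columns as (input, basis) -- correct here
b = X.reshape(1, 5, 2)  # groups them as (basis, input) -- scrambles pairs

print(a[0, 0])  # [0 1 2 3 4]  -> all bases of input 0
print(b[0, 0])  # [0 1]        -> mixes the two inputs
```

Only one of the two reshapes recovers the per-input blocks, and which one depends on how the convolution laid out the columns; split_by_feature spares the user from knowing that.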

I also don't think the tutorial we have on identifiability constraints right now is sufficient. Might not need to be addressed here, but should be added to the docs project if not. Should explain why it's bad to be rank-deficient, show what you gain from being full rank, etc.

Yes, totally

BalzaniEdoardo avatar Nov 06 '24 19:11 BalzaniEdoardo