nemos
Start facilitating coefficient splitting
Summary
This PR starts the process of making it easy to slice the feature axis of either the model coefficients or the design matrix.
In particular, it introduces the following in Basis:
- label parameter: names the variable; write only. The label of AdditiveBasis and MultiplicativeBasis combines the labels of the 1D bases.
- n_basis_input parameter: write only; the size of axis 1 of a 2D input to the basis. This occurs in convolution, for example when we convolve the counts of a neural population.
- _get_feature_slicing: a recursive method that returns a dictionary of slices that can be applied to the feature axis. Labels (or combinations of labels) are used as keys.
- split_feature_axis: the user-facing method that splits an array into its feature components.
- num_output_features: a read-only property that returns the number of output features, i.e. the size of the feature axis.
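To make the recursion concrete, here is a minimal sketch of how such a slicing dictionary could be built. It uses hypothetical stand-in classes (Leaf, Add) and a hypothetical get_feature_slicing function, not the nemos classes. It assumes each 1D basis lays out n_basis_funcs contiguous columns per input, one input after another, and that a single-input basis maps to a plain slice; multiplicative bases are left out (in the example below they map to a single slice).

```python
# Hypothetical stand-ins for 1D and additive bases; not the nemos classes.
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    """A 1D basis: n_basis_funcs features for each of n_basis_input inputs."""
    label: str
    n_basis_funcs: int
    n_basis_input: int = 1

    @property
    def num_output_features(self) -> int:
        return self.n_basis_funcs * self.n_basis_input


@dataclass
class Add:
    """An additive basis: the left features followed by the right features."""
    left: "Node"
    right: "Node"

    @property
    def num_output_features(self) -> int:
        return self.left.num_output_features + self.right.num_output_features


Node = Union[Leaf, Add]


def get_feature_slicing(basis: Node, start: int = 0, split_by_input: bool = True):
    """Return (dict of slices keyed by label, index after the last feature)."""
    if isinstance(basis, Add):
        # Recurse left, then right, threading the running start index through.
        left, start = get_feature_slicing(basis.left, start, split_by_input)
        right, start = get_feature_slicing(basis.right, start, split_by_input)
        return {**left, **right}, start
    stop = start + basis.num_output_features
    if split_by_input and basis.n_basis_input > 1:
        # One contiguous block of n_basis_funcs columns per input.
        step = basis.n_basis_funcs
        per_input = {
            str(i): slice(start + i * step, start + (i + 1) * step)
            for i in range(basis.n_basis_input)
        }
        return {basis.label: per_input}, stop
    return {basis.label: slice(start, stop)}, stop
```

With Leaf("position", 3, 2) and Leaf("velocity", 4, 3), get_feature_slicing(Add(pos, vel))[0] gives the same position and velocity entries as the slice dictionary in the example below.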
The _get_feature_slicing method is for internal use, but it will make our life much easier. In split_feature_axis we could use jax.tree_util.tree_map to apply the slicing to any array and automatically get a dictionary of coefficients, labeled in a meaningful way.
If the arrays are NumPy arrays, slicing is very efficient: basic slicing returns views, so the result is a dict of views of the array and no data is copied.
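A dependency-free sketch of that idea: the hypothetical helper tree_map_dict plays the role jax.tree_util.tree_map would play for a nested dict of slices, and the hypothetical split_feature_axis applies each slice along a chosen axis while leaving every other axis untouched. This is an illustration, not the implementation in this PR.

```python
import numpy as np


def tree_map_dict(fn, tree):
    """Apply fn to every leaf of a nested dict (a stand-in for tree_map here)."""
    if isinstance(tree, dict):
        return {key: tree_map_dict(fn, value) for key, value in tree.items()}
    return fn(tree)


def split_feature_axis(array, slice_dict, axis=1):
    """Slice `array` along `axis` with every slice in a (possibly nested) dict."""

    def take(sl):
        index = [slice(None)] * array.ndim
        index[axis] = sl  # restrict only the feature axis
        return array[tuple(index)]

    return tree_map_dict(take, slice_dict)


slices = {"feature_1": slice(0, 5), "feature_2": slice(5, 11)}
out = split_feature_axis(np.ones((1, 1, 11, 1)), slices, axis=2)
print(out["feature_1"].shape, out["feature_2"].shape)  # (1, 1, 5, 1) (1, 1, 6, 1)
```

Because the helper only recurses into dicts, the same call handles both the flat dictionary from split_by_input=False and the nested per-input one.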
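The view claim is easy to check with plain NumPy. Indexing with slice objects (basic slicing) shares memory with the original array, while indexing with a list of integers (advanced indexing) makes a copy. The array and slice dictionary below are toy values.

```python
import numpy as np

X = np.zeros((10, 18))
slice_dict = {"position": slice(0, 6), "velocity": slice(6, 18)}

# Basic slicing with slice objects returns views: no data is copied.
views = {label: X[:, sl] for label, sl in slice_dict.items()}
print(all(np.shares_memory(v, X) for v in views.values()))  # True

# Writing through a view modifies the original array.
views["position"][0, 0] = 42.0
print(X[0, 0])  # 42.0

# Advanced indexing with a list of integers returns a copy instead.
print(np.shares_memory(X[:, [0, 1, 2]], X))  # False
```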
It will also facilitate creating a "FeaturePytree" from an input in array or TsdFrame form.
Example of _get_feature_slicing:
>>> import nemos as nmo
>>> import jax
>>> import numpy as np
>>> bas1 = nmo.basis.RaisedCosineBasisLinear(3, mode="conv", n_basis_input=2, window_size=5, label="position")
>>> bas2 = nmo.basis.MSplineBasis(4, mode="conv", n_basis_input=3, window_size=5, label="velocity")
>>> bas3 = bas1 + bas2 + bas1 * bas2
>>> # slice each individual input, default behavior.
>>> slice_dict = bas3._get_feature_slicing()[0]
>>> slice_dict
{'position': {'0': slice(0, 3, None), '1': slice(3, 6, None)},
'velocity': {'0': slice(6, 10, None),
'1': slice(10, 14, None),
'2': slice(14, 18, None)},
'(position * velocity)': slice(18, 90, None)}
>>> # slice each additive component instead.
>>> bas3._get_feature_slicing(split_by_input=False)[0]
{'position': slice(0, 6, None),
'velocity': slice(6, 18, None),
'(position * velocity)': slice(18, 90, None)}
>>> # splitting a design matrix becomes trivial
>>> x1 = np.random.normal(size=(10, 2))
>>> x2 = np.random.normal(size=(10, 3))
>>> X = bas3.compute_features(x1, x2, x1, x2)
>>> splits = jax.tree_util.tree_map(lambda sl: X[:, sl], slice_dict)
Example of split_feature_axis:
>>> import numpy as np
>>> from nemos.basis import BSplineBasis
>>> from nemos.glm import GLM
>>> # Define an additive basis
>>> basis = (
... BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") +
... BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2")
... )
>>> # split an arbitrarily shaped array
>>> array = np.ones((1, 1, basis.num_output_features, 1))
>>> basis.split_feature_axis(array, axis=2)
{'feature_1': array([[[[1.],
[1.],
[1.],
[1.],
[1.]]]]),
'feature_2': array([[[[1.],
[1.],
[1.],
[1.],
[1.],
[1.]]]])}
>>> # example of usage in combination with GLM
>>> X = np.random.normal(size=(20, basis.num_output_features))
>>> y = np.random.poisson(size=(20, ))
>>> basis.split_feature_axis(GLM().fit(X, y).coef_, axis=0)
{'feature_1': Array([-0.02247754, 0.49239248, -0.09706223, -0.30416837, 0.04843776], dtype=float32),
'feature_2': Array([-0.29889402, -0.0040512 , -0.28740323, 0.5222396 , 0.55201346,
-0.13157026], dtype=float32)}
[NOTE] The use of parentheses guarantees that the label fully specifies the order of operations in a composite basis.
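A sketch of why the parentheses matter, using a hypothetical Label class rather than the nemos implementation. Only the multiplicative format "(position * velocity)" appears in the examples above; the "+" format for additive labels is an assumption. Without parentheses, a + b * c and (a + b) * c would both produce the string "a + b * c".

```python
class Label:
    """Hypothetical label algebra; not the nemos implementation."""

    def __init__(self, text: str):
        self.text = text

    def _combine(self, other: "Label", op: str) -> "Label":
        # Wrapping every composite in parentheses records the grouping.
        return Label(f"({self.text} {op} {other.text})")

    def __add__(self, other: "Label") -> "Label":
        return self._combine(other, "+")  # assumed additive format

    def __mul__(self, other: "Label") -> "Label":
        return self._combine(other, "*")


a, b, c = Label("a"), Label("b"), Label("c")
print((a + b * c).text)    # (a + (b * c))
print(((a + b) * c).text)  # ((a + b) * c)
```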
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 97.27%. Comparing base (81b9dee) to head (bf87623). Report is 46 commits behind head on development.
@@              Coverage Diff              @@
##           development    #247      +/-   ##
===============================================
+ Coverage        96.73%   97.27%   +0.53%
===============================================
  Files               25       25
  Lines             2206     2199       -7
===============================================
+ Hits              2134     2139       +5
+ Misses              72       60      -12
I'm confused by the behavior of split_by_feature:
* Why does it return 3D arrays? If I look at the example in the docstring, it returns arrays of shape (20, 1, 5) and (20, 1, 6). What's the middle dimension? This is also the behavior for a single basis or a multiplicative basis.
Bases in conv mode allow multidimensional inputs (intended for counts, which are a TsdFrame in pynapple). The middle dimension indexes the input, i.e. the columns of the 2D input. Example below:
>>> import nemos as nmo
>>> import pynapple as nap
>>> import numpy as np
>>> counts = nap.TsdFrame(t=np.arange(100), d=np.random.poisson(size=(100, 2)))
>>> basis = nmo.basis.BSplineBasis(5, mode="conv", window_size=50)
>>> X = basis.compute_features(counts)
>>> X.shape
(100, 10)
>>> basis.split_by_feature(X, axis=1)["BSplineBasis"].shape
(100, 2, 5)
* This method only really makes sense for the additive basis, right? Otherwise, doesn't it just add the middle dimension? In that case, should it exist for the other basis objects? Or am I missing something?
I prefer that every basis have the method, for consistency. I can see a case in which one creates the additive basis in a script, but the number of components is variable, from 1 to N. If every basis has the method, you can use the exact same code to process a regular basis (1 component) and the rest.
Secondly, reshaping correctly by input is handy. Without the method, how would you split the feature axis correctly in my Python example above: X.reshape(100, 2, 5) or X.reshape(100, 5, 2)? That depends on the internals of the convolution, which one shouldn't need to memorize.
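To illustrate why the reshape order matters, the toy sketch below builds a design matrix assuming an input-major layout (all basis columns of input 0, then all of input 1), which is what the per-input slices in the PR description, {'0': slice(0, 3, None), '1': slice(3, 6, None)}, suggest. Under that assumption only reshape(n_samples, n_inputs, n_basis_funcs) groups the columns by input; the other order silently mixes inputs.

```python
import numpy as np

n_samples, n_inputs, n_basis_funcs = 4, 2, 5

# Toy design matrix with an assumed input-major layout:
# entry value = 10 * input index + basis index, so columns are easy to trace.
X = np.concatenate(
    [np.tile(10 * i + np.arange(n_basis_funcs), (n_samples, 1)) for i in range(n_inputs)],
    axis=1,
)

right = X.reshape(n_samples, n_inputs, n_basis_funcs)
wrong = X.reshape(n_samples, n_basis_funcs, n_inputs)

print(right[0, 1])  # [10 11 12 13 14]: all features of input 1
print(wrong[0, 2])  # [ 4 10]: one feature from input 0, one from input 1
```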
I also don't think our current tutorial on identifiability constraints is sufficient. It might not need to be addressed here, but if not, it should be added to the docs project. It should explain why being rank-deficient is bad, show what you gain from being full rank, etc.
Yes, totally
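As a starting point for the identifiability tutorial discussed above, a toy NumPy sketch (not nemos code) of what rank deficiency costs: if the basis columns sum to one in every row (a property B-spline bases have on their domain), adding an intercept column makes the design matrix rank-deficient, and infinitely many coefficient vectors give exactly the same predictions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "basis" whose 4 columns sum to 1 in every row (a partition of unity),
# plus an explicit intercept column of ones.
B = rng.random((50, 4))
B /= B.sum(axis=1, keepdims=True)
X = np.column_stack([np.ones(50), B])

print(X.shape[1], np.linalg.matrix_rank(X))  # 5 4: rank-deficient

# The intercept equals the sum of the basis columns, so this vector is in the null space.
null_vec = np.array([1.0, -1.0, -1.0, -1.0, -1.0])
w = rng.normal(size=5)

# Two different coefficient vectors give identical predictions, so the data
# cannot tell them apart: individual coefficients are not identifiable.
print(np.allclose(X @ w, X @ (w + 3.0 * null_vec)))  # True
```

Dropping either the intercept or one basis column restores full rank in this toy case.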