
Training dynamics with a custom metric

RylanSchaeffer opened this issue 3 years ago • 18 comments

I'm having fun playing with the Neural Tangents Cookbook.ipynb and I'd like to try extending it to multivariate regression. However, when I change the output dimension of the last layer in stax.serial, the dimensions of the predicted mean and predicted covariance remain the same. Why is this, and what do I need to change to extend to multivariate regression?

RylanSchaeffer avatar Aug 05 '20 03:08 RylanSchaeffer

Hi Rylan, could it be due to not also extending the dimensionality of y_train, by any chance? That would be my first guess: for a fully-connected network the infinite-width prior mean and covariance are the same for every output unit (due to i.i.d. weights and biases), so the covariance matrix returned by kernel_fn has size N x N instead of N x N x output_dim x output_dim to save space, regardless of the size of the last layer.

In this case, if y_train has size N x 1 and not N x output_dim, nt.predict functions will give 1D predictions without raising an error.

See also an example of multivariate (10-class) regression here: https://github.com/google/neural-tangents/blob/master/examples/infinite_fcn.py (note that, as discussed above, the output width at https://github.com/google/neural-tangents/blob/83fece767bccfe4e11872e8fc7128bb4d9325fa8/examples/infinite_fcn.py#L51 doesn't matter in the infinite-width limit, and the output width is actually inferred from y_train).
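
For example, a minimal multivariate setup would look roughly like this (the shapes are what matter; the widths, standard deviations and sizes below are just placeholders):

from jax import random
import neural_tangents as nt
from neural_tangents import stax

key, x_key, y_key, t_key = random.split(random.PRNGKey(0), 4)
x_train = random.normal(x_key, (20, 3))    # N = 20 inputs of dimension 3
y_train = random.normal(y_key, (20, 10))   # N x output_dim targets -> 10 outputs
x_test = random.normal(t_key, (5, 3))

_, _, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1))                          # output width here is irrelevant

predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-4)
mean, cov = predict_fn(x_test=x_test, get='ntk', compute_cov=True)
# mean: (5, 10); cov: (5, 5), shared across the 10 outputs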

Lmk if this helps, and if not happy to take a look at your modified notebook if you'd like!

romanngg avatar Aug 05 '20 05:08 romanngg

Yep, that fixed the issue! I have a follow-up question, if it's alright with you?

I want to understand how a particular metric affects learning. So instead of studying learning via gradient descent on MSE (y - yhat)^T (y - yhat), I want to study learning via gradient descent on a modified MSE (y - yhat)^T D (y - yhat), where D is my metric of interest.
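
Concretely, in place of the Cookbook-style plain MSE I have something like this in mind (D here is just a placeholder for my metric, of shape (output_dim, output_dim)):

import jax.numpy as np

# plain MSE (Cookbook-style)
mse = lambda y_hat, y: 0.5 * np.mean((y_hat - y) ** 2)

# metric-weighted MSE, which is what I'd like to study
metric_mse = lambda y_hat, y: 0.5 * np.mean(np.einsum('ni,ij,nj->n', y_hat - y, D, y_hat - y))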

I thought I could insert the metric inside loss_fn, but looking more closely, loss_fn relies on predict_fn, and predict_fn is determined by gradient descent on the MSE.

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs, train_ys, diag_reg=1e-4)

My question is: how can I alter predict_fn so that gradient descent is performed on the modified MSE loss?

RylanSchaeffer avatar Aug 05 '20 18:08 RylanSchaeffer

Hey Rylan,

I'm not totally sure, but let's see if we can work something out. I wrote up a short note on my interpretation of your problem here, let me know if the details are as you intended in your question. In any case, I think you can get the dynamics for this modified problem by just multiplying the NTK by the metric. If you agree with that, then I think the easiest way to proceed would be to wrap the kernel function so that it multiplies the analytic NTK by the metric. Schematically, something like:

def metric_kernel_fn(x1, x2):
  kernel = kernel_fn(x1, x2)                      # infinite-width Kernel with .nngp and .ntk
  return kernel.replace(ntk=kernel.ntk * metric)  # rescale the analytic NTK by the metric

though you will need to do some finagling (aka broadcasting) to get the shapes to work out properly. Let me know if you get stuck on this and I'm happy to help iterate. I also haven't thought too much about the uncertainty prediction, it's possible you will also need to modify the nngp kernel to get that right.

Finally, note that the prediction functions include a trace_axes argument that indicates which output axes are taken to be diagonal. Normally, as Roman mentioned above, the analytic NTK lacks channel dimensions because it is presumed to be diagonal along them. However, in your case the kernel will clearly have nontrivial channel dimensions induced by the metric. Therefore you probably need to set trace_axes=() for inference here.
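
To make the broadcasting concrete, working directly with the kernel array for a constant metric, I'd expect something roughly like the following (reusing the np / nt imports and kernel_fn / x_train / y_train / output_dim from your setup; I'm also assuming the (n1, n2, out, out) axis layout for untraced kernels, so please double check the ordering against your shapes):

metric = np.eye(output_dim)                  # placeholder for your D, shape (output_dim, output_dim)
k_dd = kernel_fn(x_train, x_train, 'ntk')    # analytic NTK, shape (N, N)
k_dd_full = k_dd[:, :, None, None] * metric  # broadcast to (N, N, output_dim, output_dim)

predict_fn = nt.predict.gradient_descent_mse(
    k_dd_full, y_train, trace_axes=())       # note trace_axes=()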

Let me know if anything here is unclear or if I've misunderstood your question!

sschoenholz avatar Aug 06 '20 17:08 sschoenholz

Hi Sam! Thanks for taking the time to help me with my problem! Your note is correct, but I omitted a detail in my earlier comment because I was trying to keep things simple. In my case, the metric (your D, my T below) also depends on the example index (your i):

[image: the loss written with a per-example metric T_i in place of a single D]

More specifically, the i-th metric is time-varying as a function of the sign of the prediction error f_i(t) - y_i. I don't think something like the suggested metric_kernel_fn will work, because I need access to f_i(t) - y_i to compute the corresponding metrics in the first place.

(Also, can I just say, your work is amazing. I attended a talk you gave at Uber's Science Symposium back when I worked at Uber and thought your explanations were top notch).

RylanSchaeffer avatar Aug 06 '20 18:08 RylanSchaeffer

Rewritten in vectorized notation:

[image: the same loss rewritten in vectorized notation]

RylanSchaeffer avatar Aug 06 '20 18:08 RylanSchaeffer

Just a quick comment: I'm not sure a time-dependent metric will allow for a closed-form solution (an index-only-dependent one should be OK, I think, along the lines of Sam's suggestion). In that case you may want to look at solving the ODE numerically with https://neural-tangents.readthedocs.io/en/latest/neural_tangents.predict.html#neural_tangents.predict.gradient_descent, using the metric inside the loss argument.
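
For reference, the loss argument just needs the signature loss(fx, y_hat) and should return a scalar, so a sign-dependent metric like the one you describe could look roughly like this (tau here is a made-up asymmetry parameter):

import jax.numpy as np

def sign_dependent_loss(fx, y_hat, tau=0.7):
  err = fx - y_hat
  w = np.where(err > 0, tau, 1. - tau)  # weight depends on sign(f_i(t) - y_i)
  return 0.5 * np.mean(w * err ** 2)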

romanngg avatar Aug 06 '20 18:08 romanngg

Technically, since the metric is also time-varying, there should be another term that emerges, but I don't know how to handle it so I'm ignoring it.

RylanSchaeffer avatar Aug 06 '20 18:08 RylanSchaeffer

Oops! @romanngg I only saw your comment when I refreshed the page. I'll take a look.

RylanSchaeffer avatar Aug 06 '20 19:08 RylanSchaeffer

I do want to see how far treating the metric as time-independent will carry me, but yes, I may need to consider a numerical solution instead.

RylanSchaeffer avatar Aug 06 '20 19:08 RylanSchaeffer

@romanngg do you want me to leave this issue open or close it?

RylanSchaeffer avatar Sep 03 '20 02:09 RylanSchaeffer

@romanngg , before I close this, I'm trying to use neural_tangents.predict.gradient_descent() but I'm running into a spot of bother.

When I call

predict_fn = nt.predict.gradient_descent(
    loss=expectile_regression_loss,
    k_train_train=k_train_train,
    y_train=train_ys)

I get an error inside gradient_descent() when the line _, odd, _, _ = _get_axes(k_train_train) executes:

'Kernel' object has no attribute 'ndim'

The documentation doesn't specify what a valid k_train_train should look like. Mine looks like

x = {Kernel} Kernel(nngp=DeviceArray([[ 1.3201054 ,  1.3133823 , -1.2921143 , -0.6002171 ,\n               1.2353275 ],\n             [ 1.3133823 ,  1.3087238 , -1.2892499 , -0.60428584,\n               1.2391908 ],\n             [-1.2921143 , -1.2892499 ,  1.2956082 ,  0.
 batch_axis = {int} 0
 channel_axis = {int} 1
 cov1 = {DeviceArray: 5} [1.3201054  1.3087238  1.2956082  0.43115377 1.2212754 ]
 cov2 = {DeviceArray: 5} [1.3201054  1.3087238  1.2956082  0.43115377 1.2212754 ]
 diagonal_batch = {bool} True
 diagonal_spatial = {bool} False
 is_gaussian = {bool} True
 is_input = {bool} False
 is_reversed = {bool} False
 mask1 = {NoneType} None
 mask2 = {NoneType} None
 nngp = {DeviceArray: 5} [[ 1.3201054   1.3133823  -1.2921143  -0.6002171   1.2353275 ]\n [ 1.3133823   1.3087238  -1.2892499  -0.60428584  1.2391908 ]\n [-1.2921143  -1.2892499   1.2956082   0.62317175 -1.2284737 ]\n [-0.6002171  -0.60428584  0.62317175  0.43115377 -0.62148845]\n [ 1.2353275   1.2391908  -1.2284737  -0.62148845  1.2212754 ]]
 ntk = {DeviceArray: 5} [[ 9.556314   8.988411  -8.319939  -1.9175906  6.278597 ]\n [ 8.988411   8.562843  -8.015796  -1.9313575  6.2393656]\n [-8.319939  -8.015796   7.7513213  1.9781551 -6.103701 ]\n [-1.9175906 -1.9313575  1.9781551  1.3127701 -1.9876453]\n [ 6.278597   6.2393656 -6.103701  -1.9876453  5.5840573]]
 shape1 = {tuple: 2} (5, 11)
 shape2 = {tuple: 2} (5, 11)
 x1_is_x2 = {bool} True

I construct my training-training kernel via:

import jax.numpy as np
from jax import random
from neural_tangents import stax

key = random.PRNGKey(0)  # any PRNG seed
train_points = 5

target_fn = lambda x: np.sin(x)
key, x_key, y_key = random.split(key, 3)
train_xs = random.uniform(x_key, (train_points, 1), minval=-np.pi, maxval=np.pi)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(num_expectiles, W_std=1.5, b_std=0.05)  # num_expectiles = my output dimension
)

k_train_train = kernel_fn(train_xs, train_xs)

What am I doing wrong such that my kernel is missing the desired attribute ndim?

RylanSchaeffer avatar Sep 03 '20 03:09 RylanSchaeffer

My guess is that k_train_train = kernel_fn(train_xs, train_xs) returns a Kernel containing both the NNGP and the NTK, whereas nt.predict.gradient_descent() only wants one of the two as the kernel. Is this correct? Passing just one certainly makes the error disappear, but I want to check my understanding.

RylanSchaeffer avatar Sep 03 '20 03:09 RylanSchaeffer

Hey Rylan, you are correct that it requires an array as input; we try to document this with type annotations (rather than docstrings), e.g. here https://neural-tangents.readthedocs.io/en/latest/neural_tangents.predict.html#neural_tangents.predict.gradient_descent (k_train_train: np.ndarray, etc.).

You can get the necessary matrix by either kernel_fn(train_xs, train_xs, 'ntk') or kernel_fn(train_xs, train_xs).ntk.
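
That is, something like:

k_train_train = kernel_fn(train_xs, train_xs, 'ntk')  # plain (n_train, n_train) array, not a Kernel
predict_fn = nt.predict.gradient_descent(
    loss=expectile_regression_loss,
    k_train_train=k_train_train,
    y_train=train_ys)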

Re the issue, feel free to leave it open since we don't explicitly support this function in NT yet (but feel free to also close if you don't need this feature anymore)!

romanngg avatar Sep 03 '20 03:09 romanngg

@romanngg thank you for confirming! I'll leave the issue open. I have one last question (which isn't related, but since I'm here): how do I draw initial parameters for a stax.serial network to use with a kernel function? When I try

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(num_expectiles, W_std=1.5, b_std=0.05)
)

key, net_key = random.split(key)
output_shape, params = init_fn(net_key, (-1, 1))
k_train_train = kernel_fn(train_xs, train_xs, params).ntk

I get {AttributeError} 'tuple' object has no attribute 'lower'. params is a list of tuples of DeviceArrays, so why is stax trying to call .lower() on it as if it were a string?

RylanSchaeffer avatar Sep 03 '20 04:09 RylanSchaeffer

The documentation's example doesn't show how to generate a valid params (https://neural-tangents.readthedocs.io/en/latest/neural_tangents.predict.html#neural_tangents.predict.gradient_descent), so I used the code from the Neural Tangents Cookbook, but I appear to be misusing it.

RylanSchaeffer avatar Sep 03 '20 04:09 RylanSchaeffer

Oddly, fx_train_0 = apply_fn(params, train_xs) works and returns a DeviceArray with the correct shape. So why does params fail when used with kernel_fn()?

RylanSchaeffer avatar Sep 03 '20 04:09 RylanSchaeffer

The kernel_fn returned by stax.py functions is the closed-form, infinite-width kernel function, which doesn't need parameters (or rather, it has an infinite number of parameters, and their distribution is specified when you construct your stax network, so you don't need to pass anything to it). apply_fn is the finite-width forward-pass function, so it does require parameters.

If you want an empirical kernel with given params, i.e. the outer product of finite-width outputs / jacobians, look into https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html (and perhaps https://neural-tangents.readthedocs.io/en/latest/neural_tangents.monte_carlo.html if you want many samples)

AFAIK we indeed don't have examples there, but roughly your code would look like:

kernel_fn_empirical = nt.empirical.empirical_kernel_fn(f)  # f(params, inputs) is any function, not necessarily built with stax
nngp = kernel_fn_empirical(x1, x2, 'nngp', params)
ntk = kernel_fn_empirical(x1, x2, 'ntk', params)
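
In your case f is just apply_fn, so with the params from your init_fn call above it would be roughly (sketching; passing x2=None means x2 = x1):

kernel_fn_empirical = nt.empirical.empirical_kernel_fn(apply_fn)
k_train_train_empirical = kernel_fn_empirical(train_xs, None, 'ntk', params)  # shape (n_train, n_train)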

romanngg avatar Sep 03 '20 04:09 romanngg

Brilliant. Thanks for the clarification!

RylanSchaeffer avatar Sep 03 '20 04:09 RylanSchaeffer