Roman Novak

80 comments of Roman Novak

`nt.linearize` is essentially a Jacobian-vector product (`jax.jvp`), and the peak memory consumption of the linearized forward pass should be about 2x the peak memory consumption of the forward pass. Then,...
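A minimal sketch of what that means in code (the `stax` model and shapes below are made up for illustration):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Toy model; the architecture and shapes here are arbitrary.
init_fn, apply_fn, _ = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(10))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 784))
_, params = init_fn(key, x.shape)

# f_lin(p, x) = f(p0, x) + df/dp|_{p0} (p - p0), computed via jax.jvp,
# so the linearized forward pass carries a tangent alongside the primal
# and peaks at roughly 2x the memory of a plain forward pass.
f_lin = nt.linearize(apply_fn, params)
y_lin = f_lin(params, x)  # equals apply_fn(params, x) when evaluated at params
```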

Thanks for bringing this to our attention - this definitely looks like a bug. I think we are making an implicit assumption of being in the overparameterized regime and just...

IIUC you'd want to pass `t=None`, which we treat as symbolic infinity and choose a code path that uses the simplified expression from Eq. 16 in https://arxiv.org/pdf/1902.06720.pdf
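A hedged sketch of the `t=None` code path (the `kernel_fn` and data below are arbitrary placeholders, just to show the call):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Arbitrary infinite-width model and random data, for illustration only.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key = jax.random.PRNGKey(0)
x_train = jax.random.normal(key, (20, 16))
y_train = jax.random.normal(key, (20, 1))
x_test = jax.random.normal(key, (5, 16))

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)

# t=None is treated as symbolic infinity, so the closed-form t -> inf
# expression (a la Eq. 16 of arXiv:1902.06720) is used instead of
# integrating the finite-time training dynamics.
mean_at_infinity = predict_fn(t=None, x_test=x_test, get='ntk')
```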

I believe the answer is yes, but for the linearization of the neural network function (not the original neural network function; the wider the network is, the closer they...

As far as I understand, for specific finite width `n`, there will always be a mismatch between your network and its linearization, and we don't have super practical bounds on...

Your code sample looks good to me for finite-width masking (`init_fn`, `apply_fn` should work), but in the infinite-width case (`kernel_fn`) activations along the `channel_axis` are considered infinite, so you can't...
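A minimal sketch of the finite-width version, assuming the goal is to zero out a subset of hidden channels (the two-block split and mask below are hypothetical):

```python
import jax
import jax.numpy as jnp
from neural_tangents import stax

# Split the network so a mask can be applied to the hidden activations.
init_1, apply_1, _ = stax.serial(stax.Dense(512), stax.Relu())
init_2, apply_2, _ = stax.serial(stax.Dense(1))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 16))
shape_1, params_1 = init_1(key, x.shape)
_, params_2 = init_2(key, shape_1)

# Disable half of the 512 finite-width channels.
mask = jnp.concatenate([jnp.zeros(256), jnp.ones(256)])

def masked_apply(params, x, mask):
  p1, p2 = params
  h = apply_1(p1, x) * mask  # masking along the channel axis is well-defined here
  return apply_2(p2, h)

y = masked_apply((params_1, params_2), x, mask)  # shape (8, 1)
```

In the infinite-width `kernel_fn` there is no per-channel analogue of `mask`, since individual channels are integrated out.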

Sorry for the delay - I think this should definitely be possible in principle, but indeed we haven't implemented this yet + it really depends on your exact interpretation of...

Could you clarify the shape of `model(w, train_x)` - is it `(batch_size, output_size)`? In this case, if you specify `trace_axes=()`, the empirical NTK is `(batch_size, batch_size, output_size, output_size)`. If you specify...
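A small sketch of the shape difference (model and data are arbitrary placeholders):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Placeholder finite-width model with output_size = 3.
init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(3))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 16))
_, params = init_fn(key, x.shape)

# trace_axes=(): keep the full (output_size x output_size) block per pair of inputs.
k_full = nt.empirical_ntk_fn(apply_fn, trace_axes=())(x, None, params)
# k_full.shape == (8, 8, 3, 3)

# Default trace_axes=(-1,): trace over the output axis.
k_traced = nt.empirical_ntk_fn(apply_fn, trace_axes=(-1,))(x, None, params)
# k_traced.shape == (8, 8)
```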

I'm admittedly not familiar with `pprof`. To double check, is the issue only present when you use Neural Tangents, or when profiling other codebases as well? If the latter, it...

+1 to Sam re padding, and also note that even unpadded, the intermediate NNGP covariance of shape 10x10x(64x64)x(64x64) is 6.25 GB. To propagate this tensor through the NNGP computation from...
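A quick back-of-the-envelope check, assuming float32 entries:

```python
# 10 x 10 x (64*64) x (64*64) covariance entries, 4 bytes each.
n_elements = 10 * 10 * (64 * 64) * (64 * 64)  # 1,677,721,600 entries
n_bytes = 4 * n_elements                      # 6,710,886,400 bytes
print(n_bytes / 2**30)                        # ~6.25 GB
```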