Mohamad Amin Mohamadi
Hey! Thanks for all the amazing work. I'm trying to compute the NTK for some data on a shallow version of WideResNet, and I'm encountering a non-PSD matrix which results...
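An empirical NTK is a Gram matrix of Jacobian inner products, so on any set of inputs it is PSD in exact arithmetic. Small negative eigenvalues can therefore come from floating-point round-off, especially in float32. Below is a minimal NumPy sketch for telling round-off apart from genuine indefiniteness. It is not part of neural-tangents and does not diagnose the WideResNet case above. The helper names, the tolerance, and the jitter value are hypothetical.

```python
import numpy as np


def symmetrize_and_check_psd(kernel: np.ndarray, tol: float = 1e-8):
    """Symmetrize a kernel matrix and report its smallest eigenvalue.

    Returns (symmetrized kernel, smallest eigenvalue, whether it is >= -tol).
    """
    sym = 0.5 * (kernel + kernel.T)          # remove asymmetric round-off
    min_eig = np.linalg.eigvalsh(sym).min()  # eigvalsh expects a symmetric input
    return sym, min_eig, bool(min_eig >= -tol)


def add_jitter(kernel: np.ndarray, jitter: float) -> np.ndarray:
    """Add jitter * I to the diagonal, a common numerical stabilizer."""
    return kernel + jitter * np.eye(kernel.shape[0], dtype=kernel.dtype)


# Usage: X X^T is PSD in exact arithmetic, but float32 round-off can push
# its (near-)zero eigenvalues slightly below zero.
rng = np.random.default_rng(0)
feats = rng.standard_normal((50, 5)).astype(np.float32)
gram = feats @ feats.T  # rank <= 5, so most eigenvalues are ~0
sym, min_eig, is_psd = symmetrize_and_check_psd(gram)
print(min_eig, is_psd)
```

If the smallest eigenvalue is only slightly negative relative to the largest one, adding a small jitter is usually enough; a large negative eigenvalue points to a real bug rather than round-off.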
I was just reading through the file because I wanted to apply some modifications, and I saw this function: https://github.com/google/neural-tangents/blob/3deb1972d7299bff63659c7ff62910cf0fc56cdf/neural_tangents/utils/batch.py#L126 and this comment: > """Implements an unrolled version of scan....
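For context on the quoted comment: the `jax.lax.scan` documentation describes scan's semantics with a plain Python loop. The NumPy sketch below writes out that loop. A loop like this counts as "unrolled" because tracing it under `jax.jit` emits every iteration as separate operations instead of a single compiled loop. This illustrates the term only. It is not the neural-tangents function at `batch.py#L126`, and the name `unrolled_scan` is hypothetical.

```python
import numpy as np


def unrolled_scan(f, init, xs):
    """Python-level scan with the same carry/output contract as jax.lax.scan.

    Being an ordinary Python loop, tracing it under jax.jit would unroll each
    iteration into the compiled graph rather than emitting one loop primitive.
    """
    carry = init
    ys = []
    for x in xs:                # iterate over the leading axis of xs
        carry, y = f(carry, x)  # f returns (new_carry, per-step output)
        ys.append(y)
    return carry, np.stack(ys)  # stack per-step outputs along a new axis 0


# Usage: a running sum where each step emits the partial sum so far.
def step(total, x):
    total = total + x
    return total, total


final, partial_sums = unrolled_scan(step, 0.0, np.arange(1.0, 5.0))
print(final)         # 10.0
print(partial_sums)  # [ 1.  3.  6. 10.]
```

The trade-off is compile time and graph size, which grow linearly with the number of iterations, in exchange for letting the compiler optimize across steps.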
Hey! I was curious about why you chose this as the batch size: https://github.com/google/neural-tangents/blob/9f2ebc88905c46d60b7c4a9da25636924acc9d45/neural_tangents/utils/batch.py#L611 I'm not entirely familiar with the kernel computation algorithm, but according to this, is the kernel...
Hey! I was reading through the code and I noticed that you're using an element-wise exponential of a matrix here: https://github.com/google/neural-tangents/blob/5f286b7696364217aa4a2d92378aabd0203a791e/neural_tangents/predict.py#L1180 Does this correspond to equation (9) in the paper https://arxiv.org/pdf/1902.06720? If...
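The distinction behind this question is between exponentiating a matrix's entries and taking its matrix exponential, as in terms of the form $e^{-\eta \Theta t}$ that appear in linearized MSE dynamics (up to the paper's normalization). For a symmetric $\Theta = V \operatorname{diag}(\lambda) V^\top$, the matrix exponential is $V \operatorname{diag}(e^{-\eta \lambda t}) V^\top$. So an element-wise exp applied to the *eigenvalues* is exact, while one applied to the raw entries is not. Whether the linked line exponentiates eigenvalues or entries cannot be confirmed from this excerpt. The NumPy sketch below only checks the math. The helper names and `eta_t` value are hypothetical, and a truncated Taylor series stands in as the reference matrix exponential.

```python
import numpy as np


def expm_via_eigh(sym: np.ndarray, scale: float) -> np.ndarray:
    """exp(scale * sym) for a symmetric matrix, via its eigendecomposition."""
    evals, evecs = np.linalg.eigh(sym)
    # Element-wise exp of the *eigenvalues*, then rotate back: V diag(.) V^T.
    return (evecs * np.exp(scale * evals)) @ evecs.T


def expm_taylor(a: np.ndarray, terms: int = 60) -> np.ndarray:
    """Reference matrix exponential by truncated Taylor series (small ||a|| only)."""
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for k in range(1, terms):
        term = term @ a / k  # accumulates a^k / k!
        result = result + term
    return result


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
theta = a @ a.T / 4  # symmetric PSD stand-in for an NTK Gram matrix
eta_t = 0.5          # hypothetical learning rate times training time

true_expm = expm_taylor(-eta_t * theta)
print(np.allclose(expm_via_eigh(theta, -eta_t), true_expm))  # True
print(np.allclose(np.exp(-eta_t * theta), true_expm))        # False
```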
This is not a bug, more of a suggestion. Are there any plans to integrate [Architecture Components](https://developer.android.com/topic/libraries/architecture/index.html) with this library? I have some suggestions, such as: * Having some modules to integrate...
Hey, thanks for the great work! I'm using BatchNorm in my network, but I have set the `use_running_average` parameter of the BatchNorm layers to `True`, which means it will not compute any...
Hey, I would like to calculate the mentioned Jacobians. Right now I'm trying this:
```python
from functorch import jacrev, make_functional_with_buffers

func, params, buffers = make_functional_with_buffers(model)
J = jacrev(lambda p: func(p, buffers, input_dict))(params)
```
But this...
Hello, thanks for the great work. In some cases, I'd like to load the dataset after some transformations and do some analysis on the datapoints before feeding them to the...
Hey! I'm trying to compute the results of multiple kernel ridge regressions in parallel. I've written the code and created JAX expressions of my functions using `jax.make_jaxpr`. According...
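To illustrate running several kernel ridge regressions at once, here is a minimal NumPy sketch. It stacks one regularized system per ridge value and solves them all in a single batched `np.linalg.solve` call, using the standard predictor $K(X_*, X)\,(K(X, X) + \lambda I)^{-1} Y$. The RBF kernel, the helper names, and batching over ridge values (rather than over datasets) are illustrative assumptions, not the poster's setup. In JAX the same pattern is usually written with `jax.vmap`.

```python
import numpy as np


def rbf_kernel(x1: np.ndarray, x2: np.ndarray, lengthscale: float = 1.0) -> np.ndarray:
    """Hypothetical stand-in kernel; swap in an NTK/NNGP kernel in practice."""
    sq = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def batched_krr_predict(k_train, k_test_train, y_train, ridges):
    """Solve one kernel ridge regression per ridge value in a single batched call.

    k_train: (n, n), k_test_train: (m, n), y_train: (n, d), ridges: (b,)
    returns: (b, m, d) predictions
    """
    n = k_train.shape[0]
    # Stack b regularized systems K + ridge_i * I into shape (b, n, n).
    systems = k_train[None] + ridges[:, None, None] * np.eye(n)[None]
    # Broadcast targets to (b, n, d) so each system has its own right-hand side.
    rhs = np.broadcast_to(y_train, (len(ridges),) + y_train.shape)
    alphas = np.linalg.solve(systems, rhs)  # (b, n, d)
    return k_test_train[None] @ alphas      # (b, m, d)


rng = np.random.default_rng(0)
x_train, x_test = rng.standard_normal((20, 3)), rng.standard_normal((5, 3))
y_train = rng.standard_normal((20, 1))
ridges = np.array([1e-3, 1e-1, 1.0])

k_tt, k_st = rbf_kernel(x_train, x_train), rbf_kernel(x_test, x_train)
preds = batched_krr_predict(k_tt, k_st, y_train, ridges)

# Same result as solving each problem separately in a Python loop.
for i, r in enumerate(ridges):
    single = k_st @ np.linalg.solve(k_tt + r * np.eye(20), y_train)
    assert np.allclose(preds[i], single)
print(preds.shape)  # (3, 5, 1)
```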
Hello, when I try to compute the NTK of a model with an embedding layer, I get the following warning: ``` /usr/local/lib/python3.10/dist-packages/neural_tangents/_src/empirical.py:2215: UserWarning: No Jacobian rule found for gather. warnings.warn(f'No...
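Relevant to the `gather` warning: an embedding lookup `table[ids]` is mathematically identical to multiplying a one-hot encoding of `ids` by the table. In JAX, integer-array indexing lowers to `gather`, while the one-hot form is a matrix product (`jax.nn.one_hot` builds the encoding). The NumPy sketch below only checks that equivalence. Whether the one-hot form lets neural-tangents use a structured Jacobian rule is an untested assumption, and the helper names are hypothetical.

```python
import numpy as np


def embed_gather(table: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Standard embedding lookup by indexing (lowers to `gather` in JAX)."""
    return table[ids]


def embed_one_hot(table: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Same lookup as a one-hot matrix product (a dot product, no indexing)."""
    vocab = table.shape[0]
    one_hot = (ids[..., None] == np.arange(vocab)).astype(table.dtype)  # (..., vocab)
    return one_hot @ table                                              # (..., dim)


rng = np.random.default_rng(0)
table = rng.standard_normal((10, 4))    # vocabulary of 10, embedding dim 4
ids = np.array([[1, 3, 3], [0, 9, 2]])  # (batch=2, seq=3)
print(np.allclose(embed_gather(table, ids), embed_one_hot(table, ids)))  # True
```

The one-hot form materializes a `(batch, seq, vocab)` array, so its memory cost grows with the vocabulary size; it is practical mainly for small vocabularies.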